xref: /llvm-project/llvm/lib/Analysis/DemandedBits.cpp (revision 236fda550d36d35a00785938c3e38b0f402aeda6)
1 //===- DemandedBits.cpp - Determine demanded bits -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements a demanded bits analysis. A demanded bit is one that
10 // contributes to a result; bits that are not demanded can be either zero or
11 // one without affecting control or data flow. For example in this sequence:
12 //
13 //   %1 = add i32 %x, %y
14 //   %2 = trunc i32 %1 to i16
15 //
16 // Only the lowest 16 bits of %1 are demanded; the rest are removed by the
17 // trunc.
18 //
19 //===----------------------------------------------------------------------===//
20 
21 #include "llvm/Analysis/DemandedBits.h"
22 #include "llvm/ADT/APInt.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Analysis/AssumptionCache.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/IR/DataLayout.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/InstIterator.h"
29 #include "llvm/IR/Instruction.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/IR/PassManager.h"
33 #include "llvm/IR/PatternMatch.h"
34 #include "llvm/IR/Type.h"
35 #include "llvm/IR/Use.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/KnownBits.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include <algorithm>
41 #include <cstdint>
42 
43 using namespace llvm;
44 using namespace llvm::PatternMatch;
45 
46 #define DEBUG_TYPE "demanded-bits"
47 
48 static bool isAlwaysLive(Instruction *I) {
49   return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() ||
50          I->mayHaveSideEffects();
51 }
52 
53 void DemandedBits::determineLiveOperandBits(
54     const Instruction *UserI, const Value *Val, unsigned OperandNo,
55     const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2,
56     bool &KnownBitsComputed) {
57   unsigned BitWidth = AB.getBitWidth();
58 
59   // We're called once per operand, but for some instructions, we need to
60   // compute known bits of both operands in order to determine the live bits of
61   // either (when both operands are instructions themselves). We don't,
62   // however, want to do this twice, so we cache the result in APInts that live
63   // in the caller. For the two-relevant-operands case, both operand values are
64   // provided here.
65   auto ComputeKnownBits =
66       [&](unsigned BitWidth, const Value *V1, const Value *V2) {
67         if (KnownBitsComputed)
68           return;
69         KnownBitsComputed = true;
70 
71         const DataLayout &DL = UserI->getDataLayout();
72         Known = KnownBits(BitWidth);
73         computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT);
74 
75         if (V2) {
76           Known2 = KnownBits(BitWidth);
77           computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT);
78         }
79       };
80 
81   switch (UserI->getOpcode()) {
82   default: break;
83   case Instruction::Call:
84   case Instruction::Invoke:
85     if (const auto *II = dyn_cast<IntrinsicInst>(UserI)) {
86       switch (II->getIntrinsicID()) {
87       default: break;
88       case Intrinsic::bswap:
89         // The alive bits of the input are the swapped alive bits of
90         // the output.
91         AB = AOut.byteSwap();
92         break;
93       case Intrinsic::bitreverse:
94         // The alive bits of the input are the reversed alive bits of
95         // the output.
96         AB = AOut.reverseBits();
97         break;
98       case Intrinsic::ctlz:
99         if (OperandNo == 0) {
100           // We need some output bits, so we need all bits of the
101           // input to the left of, and including, the leftmost bit
102           // known to be one.
103           ComputeKnownBits(BitWidth, Val, nullptr);
104           AB = APInt::getHighBitsSet(BitWidth,
105                  std::min(BitWidth, Known.countMaxLeadingZeros()+1));
106         }
107         break;
108       case Intrinsic::cttz:
109         if (OperandNo == 0) {
110           // We need some output bits, so we need all bits of the
111           // input to the right of, and including, the rightmost bit
112           // known to be one.
113           ComputeKnownBits(BitWidth, Val, nullptr);
114           AB = APInt::getLowBitsSet(BitWidth,
115                  std::min(BitWidth, Known.countMaxTrailingZeros()+1));
116         }
117         break;
118       case Intrinsic::fshl:
119       case Intrinsic::fshr: {
120         const APInt *SA;
121         if (OperandNo == 2) {
122           // Shift amount is modulo the bitwidth. For powers of two we have
123           // SA % BW == SA & (BW - 1).
124           if (isPowerOf2_32(BitWidth))
125             AB = BitWidth - 1;
126         } else if (match(II->getOperand(2), m_APInt(SA))) {
127           // Normalize to funnel shift left. APInt shifts of BitWidth are well-
128           // defined, so no need to special-case zero shifts here.
129           uint64_t ShiftAmt = SA->urem(BitWidth);
130           if (II->getIntrinsicID() == Intrinsic::fshr)
131             ShiftAmt = BitWidth - ShiftAmt;
132 
133           if (OperandNo == 0)
134             AB = AOut.lshr(ShiftAmt);
135           else if (OperandNo == 1)
136             AB = AOut.shl(BitWidth - ShiftAmt);
137         }
138         break;
139       }
140       case Intrinsic::umax:
141       case Intrinsic::umin:
142       case Intrinsic::smax:
143       case Intrinsic::smin:
144         // If low bits of result are not demanded, they are also not demanded
145         // for the min/max operands.
146         AB = APInt::getBitsSetFrom(BitWidth, AOut.countr_zero());
147         break;
148       }
149     }
150     break;
151   case Instruction::Add:
152     if (AOut.isMask()) {
153       AB = AOut;
154     } else {
155       ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
156       AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
157     }
158     break;
159   case Instruction::Sub:
160     if (AOut.isMask()) {
161       AB = AOut;
162     } else {
163       ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
164       AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
165     }
166     break;
167   case Instruction::Mul:
168     // Find the highest live output bit. We don't need any more input
169     // bits than that (adds, and thus subtracts, ripple only to the
170     // left).
171     AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
172     break;
173   case Instruction::Shl:
174     if (OperandNo == 0) {
175       const APInt *ShiftAmtC;
176       if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
177         uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
178         AB = AOut.lshr(ShiftAmt);
179 
180         // If the shift is nuw/nsw, then the high bits are not dead
181         // (because we've promised that they *must* be zero).
182         const auto *S = cast<ShlOperator>(UserI);
183         if (S->hasNoSignedWrap())
184           AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
185         else if (S->hasNoUnsignedWrap())
186           AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
187       }
188     }
189     break;
190   case Instruction::LShr:
191     if (OperandNo == 0) {
192       const APInt *ShiftAmtC;
193       if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
194         uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
195         AB = AOut.shl(ShiftAmt);
196 
197         // If the shift is exact, then the low bits are not dead
198         // (they must be zero).
199         if (cast<LShrOperator>(UserI)->isExact())
200           AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
201       }
202     }
203     break;
204   case Instruction::AShr:
205     if (OperandNo == 0) {
206       const APInt *ShiftAmtC;
207       if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
208         uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
209         AB = AOut.shl(ShiftAmt);
210         // Because the high input bit is replicated into the
211         // high-order bits of the result, if we need any of those
212         // bits, then we must keep the highest input bit.
213         if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
214             .getBoolValue())
215           AB.setSignBit();
216 
217         // If the shift is exact, then the low bits are not dead
218         // (they must be zero).
219         if (cast<AShrOperator>(UserI)->isExact())
220           AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
221       }
222     }
223     break;
224   case Instruction::And:
225     AB = AOut;
226 
227     // For bits that are known zero, the corresponding bits in the
228     // other operand are dead (unless they're both zero, in which
229     // case they can't both be dead, so just mark the LHS bits as
230     // dead).
231     ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
232     if (OperandNo == 0)
233       AB &= ~Known2.Zero;
234     else
235       AB &= ~(Known.Zero & ~Known2.Zero);
236     break;
237   case Instruction::Or:
238     AB = AOut;
239 
240     // For bits that are known one, the corresponding bits in the
241     // other operand are dead (unless they're both one, in which
242     // case they can't both be dead, so just mark the LHS bits as
243     // dead).
244     ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
245     if (OperandNo == 0)
246       AB &= ~Known2.One;
247     else
248       AB &= ~(Known.One & ~Known2.One);
249     break;
250   case Instruction::Xor:
251   case Instruction::PHI:
252     AB = AOut;
253     break;
254   case Instruction::Trunc:
255     AB = AOut.zext(BitWidth);
256     break;
257   case Instruction::ZExt:
258     AB = AOut.trunc(BitWidth);
259     break;
260   case Instruction::SExt:
261     AB = AOut.trunc(BitWidth);
262     // Because the high input bit is replicated into the
263     // high-order bits of the result, if we need any of those
264     // bits, then we must keep the highest input bit.
265     if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
266                                       AOut.getBitWidth() - BitWidth))
267         .getBoolValue())
268       AB.setSignBit();
269     break;
270   case Instruction::Select:
271     if (OperandNo != 0)
272       AB = AOut;
273     break;
274   case Instruction::ExtractElement:
275     if (OperandNo == 0)
276       AB = AOut;
277     break;
278   case Instruction::InsertElement:
279   case Instruction::ShuffleVector:
280     if (OperandNo == 0 || OperandNo == 1)
281       AB = AOut;
282     break;
283   }
284 }
285 
286 void DemandedBits::performAnalysis() {
287   if (Analyzed)
288     // Analysis already completed for this function.
289     return;
290   Analyzed = true;
291 
292   Visited.clear();
293   AliveBits.clear();
294   DeadUses.clear();
295 
296   SmallSetVector<Instruction*, 16> Worklist;
297 
298   // Collect the set of "root" instructions that are known live.
299   for (Instruction &I : instructions(F)) {
300     if (!isAlwaysLive(&I))
301       continue;
302 
303     LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
304     // For integer-valued instructions, set up an initial empty set of alive
305     // bits and add the instruction to the work list. For other instructions
306     // add their operands to the work list (for integer values operands, mark
307     // all bits as live).
308     Type *T = I.getType();
309     if (T->isIntOrIntVectorTy()) {
310       if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
311         Worklist.insert(&I);
312 
313       continue;
314     }
315 
316     // Non-integer-typed instructions...
317     for (Use &OI : I.operands()) {
318       if (auto *J = dyn_cast<Instruction>(OI)) {
319         Type *T = J->getType();
320         if (T->isIntOrIntVectorTy())
321           AliveBits[J] = APInt::getAllOnes(T->getScalarSizeInBits());
322         else
323           Visited.insert(J);
324         Worklist.insert(J);
325       }
326     }
327     // To save memory, we don't add I to the Visited set here. Instead, we
328     // check isAlwaysLive on every instruction when searching for dead
329     // instructions later (we need to check isAlwaysLive for the
330     // integer-typed instructions anyway).
331   }
332 
333   // Propagate liveness backwards to operands.
334   while (!Worklist.empty()) {
335     Instruction *UserI = Worklist.pop_back_val();
336 
337     LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
338     APInt AOut;
339     bool InputIsKnownDead = false;
340     if (UserI->getType()->isIntOrIntVectorTy()) {
341       AOut = AliveBits[UserI];
342       LLVM_DEBUG(dbgs() << " Alive Out: 0x"
343                         << Twine::utohexstr(AOut.getLimitedValue()));
344 
345       // If all bits of the output are dead, then all bits of the input
346       // are also dead.
347       InputIsKnownDead = !AOut && !isAlwaysLive(UserI);
348     }
349     LLVM_DEBUG(dbgs() << "\n");
350 
351     KnownBits Known, Known2;
352     bool KnownBitsComputed = false;
353     // Compute the set of alive bits for each operand. These are anded into the
354     // existing set, if any, and if that changes the set of alive bits, the
355     // operand is added to the work-list.
356     for (Use &OI : UserI->operands()) {
357       // We also want to detect dead uses of arguments, but will only store
358       // demanded bits for instructions.
359       auto *I = dyn_cast<Instruction>(OI);
360       if (!I && !isa<Argument>(OI))
361         continue;
362 
363       Type *T = OI->getType();
364       if (T->isIntOrIntVectorTy()) {
365         unsigned BitWidth = T->getScalarSizeInBits();
366         APInt AB = APInt::getAllOnes(BitWidth);
367         if (InputIsKnownDead) {
368           AB = APInt(BitWidth, 0);
369         } else {
370           // Bits of each operand that are used to compute alive bits of the
371           // output are alive, all others are dead.
372           determineLiveOperandBits(UserI, OI, OI.getOperandNo(), AOut, AB,
373                                    Known, Known2, KnownBitsComputed);
374 
375           // Keep track of uses which have no demanded bits.
376           if (AB.isZero())
377             DeadUses.insert(&OI);
378           else
379             DeadUses.erase(&OI);
380         }
381 
382         if (I) {
383           // If we've added to the set of alive bits (or the operand has not
384           // been previously visited), then re-queue the operand to be visited
385           // again.
386           auto Res = AliveBits.try_emplace(I);
387           if (Res.second || (AB |= Res.first->second) != Res.first->second) {
388             Res.first->second = std::move(AB);
389             Worklist.insert(I);
390           }
391         }
392       } else if (I && Visited.insert(I).second) {
393         Worklist.insert(I);
394       }
395     }
396   }
397 }
398 
399 APInt DemandedBits::getDemandedBits(Instruction *I) {
400   performAnalysis();
401 
402   auto Found = AliveBits.find(I);
403   if (Found != AliveBits.end())
404     return Found->second;
405 
406   const DataLayout &DL = I->getDataLayout();
407   return APInt::getAllOnes(DL.getTypeSizeInBits(I->getType()->getScalarType()));
408 }
409 
410 APInt DemandedBits::getDemandedBits(Use *U) {
411   Type *T = (*U)->getType();
412   auto *UserI = cast<Instruction>(U->getUser());
413   const DataLayout &DL = UserI->getDataLayout();
414   unsigned BitWidth = DL.getTypeSizeInBits(T->getScalarType());
415 
416   // We only track integer uses, everything else produces a mask with all bits
417   // set
418   if (!T->isIntOrIntVectorTy())
419     return APInt::getAllOnes(BitWidth);
420 
421   if (isUseDead(U))
422     return APInt(BitWidth, 0);
423 
424   performAnalysis();
425 
426   APInt AOut = getDemandedBits(UserI);
427   APInt AB = APInt::getAllOnes(BitWidth);
428   KnownBits Known, Known2;
429   bool KnownBitsComputed = false;
430 
431   determineLiveOperandBits(UserI, *U, U->getOperandNo(), AOut, AB, Known,
432                            Known2, KnownBitsComputed);
433 
434   return AB;
435 }
436 
437 bool DemandedBits::isInstructionDead(Instruction *I) {
438   performAnalysis();
439 
440   return !Visited.count(I) && !AliveBits.contains(I) && !isAlwaysLive(I);
441 }
442 
443 bool DemandedBits::isUseDead(Use *U) {
444   // We only track integer uses, everything else is assumed live.
445   if (!(*U)->getType()->isIntOrIntVectorTy())
446     return false;
447 
448   // Uses by always-live instructions are never dead.
449   auto *UserI = cast<Instruction>(U->getUser());
450   if (isAlwaysLive(UserI))
451     return false;
452 
453   performAnalysis();
454   if (DeadUses.count(U))
455     return true;
456 
457   // If no output bits are demanded, no input bits are demanded and the use
458   // is dead. These uses might not be explicitly present in the DeadUses map.
459   if (UserI->getType()->isIntOrIntVectorTy()) {
460     auto Found = AliveBits.find(UserI);
461     if (Found != AliveBits.end() && Found->second.isZero())
462       return true;
463   }
464 
465   return false;
466 }
467 
468 void DemandedBits::print(raw_ostream &OS) {
469   auto PrintDB = [&](const Instruction *I, const APInt &A, Value *V = nullptr) {
470     OS << "DemandedBits: 0x" << Twine::utohexstr(A.getLimitedValue())
471        << " for ";
472     if (V) {
473       V->printAsOperand(OS, false);
474       OS << " in ";
475     }
476     OS << *I << '\n';
477   };
478 
479   OS << "Printing analysis 'Demanded Bits Analysis' for function '" << F.getName() << "':\n";
480   performAnalysis();
481   for (auto &KV : AliveBits) {
482     Instruction *I = KV.first;
483     PrintDB(I, KV.second);
484 
485     for (Use &OI : I->operands()) {
486       PrintDB(I, getDemandedBits(&OI), OI);
487     }
488   }
489 }
490 
491 static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
492                                               const APInt &AOut,
493                                               const KnownBits &LHS,
494                                               const KnownBits &RHS,
495                                               bool CarryZero, bool CarryOne) {
496   assert(!(CarryZero && CarryOne) &&
497          "Carry can't be zero and one at the same time");
498 
499   // The following check should be done by the caller, as it also indicates
500   // that LHS and RHS don't need to be computed.
501   //
502   // if (AOut.isMask())
503   //   return AOut;
504 
505   // Boundary bits' carry out is unaffected by their carry in.
506   APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
507 
508   // First, the alive carry bits are determined from the alive output bits:
509   // Let demand ripple to the right but only up to any set bit in Bound.
510   //   AOut         = -1----
511   //   Bound        = ----1-
512   //   ACarry&~AOut = --111-
513   APInt RBound = Bound.reverseBits();
514   APInt RAOut = AOut.reverseBits();
515   APInt RProp = RAOut + (RAOut | ~RBound);
516   APInt RACarry = RProp ^ ~RBound;
517   APInt ACarry = RACarry.reverseBits();
518 
519   // Then, the alive input bits are determined from the alive carry bits:
520   APInt NeededToMaintainCarryZero;
521   APInt NeededToMaintainCarryOne;
522   if (OperandNo == 0) {
523     NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
524     NeededToMaintainCarryOne = LHS.One | ~RHS.One;
525   } else {
526     NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
527     NeededToMaintainCarryOne = RHS.One | ~LHS.One;
528   }
529 
530   // As in computeForAddCarry
531   APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
532   APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
533 
534   // The below is simplified from
535   //
536   // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
537   // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
538   // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
539   //
540   // APInt NeededToMaintainCarry =
541   //   (CarryKnownZero & NeededToMaintainCarryZero) |
542   //   (CarryKnownOne  & NeededToMaintainCarryOne) |
543   //   CarryUnknown;
544 
545   APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
546                                 (PossibleSumOne | NeededToMaintainCarryOne);
547 
548   APInt AB = AOut | (ACarry & NeededToMaintainCarry);
549   return AB;
550 }
551 
552 APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
553                                                 const APInt &AOut,
554                                                 const KnownBits &LHS,
555                                                 const KnownBits &RHS) {
556   return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
557                                           false);
558 }
559 
560 APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
561                                                 const APInt &AOut,
562                                                 const KnownBits &LHS,
563                                                 const KnownBits &RHS) {
564   KnownBits NRHS;
565   NRHS.Zero = RHS.One;
566   NRHS.One = RHS.Zero;
567   return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
568                                           true);
569 }
570 
571 AnalysisKey DemandedBitsAnalysis::Key;
572 
573 DemandedBits DemandedBitsAnalysis::run(Function &F,
574                                              FunctionAnalysisManager &AM) {
575   auto &AC = AM.getResult<AssumptionAnalysis>(F);
576   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
577   return DemandedBits(F, AC, DT);
578 }
579 
580 PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
581                                                FunctionAnalysisManager &AM) {
582   AM.getResult<DemandedBitsAnalysis>(F).print(OS);
583   return PreservedAnalyses::all();
584 }
585