1 //===- CmpInstAnalysis.cpp - Utils to help fold compares ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file holds routines to help analyse compare instructions 10 // and fold them into constants or other compare instructions 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Analysis/CmpInstAnalysis.h" 15 #include "llvm/IR/Constants.h" 16 #include "llvm/IR/Instructions.h" 17 #include "llvm/IR/PatternMatch.h" 18 19 using namespace llvm; 20 21 unsigned llvm::getICmpCode(CmpInst::Predicate Pred) { 22 switch (Pred) { 23 // False -> 0 24 case ICmpInst::ICMP_UGT: return 1; // 001 25 case ICmpInst::ICMP_SGT: return 1; // 001 26 case ICmpInst::ICMP_EQ: return 2; // 010 27 case ICmpInst::ICMP_UGE: return 3; // 011 28 case ICmpInst::ICMP_SGE: return 3; // 011 29 case ICmpInst::ICMP_ULT: return 4; // 100 30 case ICmpInst::ICMP_SLT: return 4; // 100 31 case ICmpInst::ICMP_NE: return 5; // 101 32 case ICmpInst::ICMP_ULE: return 6; // 110 33 case ICmpInst::ICMP_SLE: return 6; // 110 34 // True -> 7 35 default: 36 llvm_unreachable("Invalid ICmp predicate!"); 37 } 38 } 39 40 Constant *llvm::getPredForICmpCode(unsigned Code, bool Sign, Type *OpTy, 41 CmpInst::Predicate &Pred) { 42 switch (Code) { 43 default: llvm_unreachable("Illegal ICmp code!"); 44 case 0: // False. 45 return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 0); 46 case 1: Pred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break; 47 case 2: Pred = ICmpInst::ICMP_EQ; break; 48 case 3: Pred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break; 49 case 4: Pred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break; 50 case 5: Pred = ICmpInst::ICMP_NE; break; 51 case 6: Pred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break; 52 case 7: // True. 53 return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 1); 54 } 55 return nullptr; 56 } 57 58 bool llvm::predicatesFoldable(ICmpInst::Predicate P1, ICmpInst::Predicate P2) { 59 return (CmpInst::isSigned(P1) == CmpInst::isSigned(P2)) || 60 (CmpInst::isSigned(P1) && ICmpInst::isEquality(P2)) || 61 (CmpInst::isSigned(P2) && ICmpInst::isEquality(P1)); 62 } 63 64 Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy, 65 CmpInst::Predicate &Pred) { 66 Pred = static_cast<FCmpInst::Predicate>(Code); 67 assert(FCmpInst::FCMP_FALSE <= Pred && Pred <= FCmpInst::FCMP_TRUE && 68 "Unexpected FCmp predicate!"); 69 if (Pred == FCmpInst::FCMP_FALSE) 70 return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 0); 71 if (Pred == FCmpInst::FCMP_TRUE) 72 return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 1); 73 return nullptr; 74 } 75 76 std::optional<DecomposedBitTest> 77 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, 78 bool LookThruTrunc, bool AllowNonZeroC) { 79 using namespace PatternMatch; 80 81 const APInt *OrigC; 82 if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC))) 83 return std::nullopt; 84 85 bool Inverted = false; 86 if (ICmpInst::isGT(Pred) || ICmpInst::isGE(Pred)) { 87 Inverted = true; 88 Pred = ICmpInst::getInversePredicate(Pred); 89 } 90 91 APInt C = *OrigC; 92 if (ICmpInst::isLE(Pred)) { 93 if (ICmpInst::isSigned(Pred) ? C.isMaxSignedValue() : C.isMaxValue()) 94 return std::nullopt; 95 ++C; 96 Pred = ICmpInst::getStrictPredicate(Pred); 97 } 98 99 DecomposedBitTest Result; 100 switch (Pred) { 101 default: 102 llvm_unreachable("Unexpected predicate"); 103 case ICmpInst::ICMP_SLT: { 104 // X < 0 is equivalent to (X & SignMask) != 0. 105 if (C.isZero()) { 106 Result.Mask = APInt::getSignMask(C.getBitWidth()); 107 Result.C = APInt::getZero(C.getBitWidth()); 108 Result.Pred = ICmpInst::ICMP_NE; 109 break; 110 } 111 112 APInt FlippedSign = C ^ APInt::getSignMask(C.getBitWidth()); 113 if (FlippedSign.isPowerOf2()) { 114 // X s< 10000100 is equivalent to (X & 11111100 == 10000000) 115 Result.Mask = -FlippedSign; 116 Result.C = APInt::getSignMask(C.getBitWidth()); 117 Result.Pred = ICmpInst::ICMP_EQ; 118 break; 119 } 120 121 if (FlippedSign.isNegatedPowerOf2()) { 122 // X s< 01111100 is equivalent to (X & 11111100 != 01111100) 123 Result.Mask = FlippedSign; 124 Result.C = C; 125 Result.Pred = ICmpInst::ICMP_NE; 126 break; 127 } 128 129 return std::nullopt; 130 } 131 case ICmpInst::ICMP_ULT: 132 // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0. 133 if (C.isPowerOf2()) { 134 Result.Mask = -C; 135 Result.C = APInt::getZero(C.getBitWidth()); 136 Result.Pred = ICmpInst::ICMP_EQ; 137 break; 138 } 139 140 // X u< 11111100 is equivalent to (X & 11111100 != 11111100) 141 if (C.isNegatedPowerOf2()) { 142 Result.Mask = C; 143 Result.C = C; 144 Result.Pred = ICmpInst::ICMP_NE; 145 break; 146 } 147 148 return std::nullopt; 149 } 150 151 if (!AllowNonZeroC && !Result.C.isZero()) 152 return std::nullopt; 153 154 if (Inverted) 155 Result.Pred = ICmpInst::getInversePredicate(Result.Pred); 156 157 Value *X; 158 if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) { 159 Result.X = X; 160 Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits()); 161 Result.C = Result.C.zext(X->getType()->getScalarSizeInBits()); 162 } else { 163 Result.X = LHS; 164 } 165 166 return Result; 167 } 168 169 std::optional<DecomposedBitTest> 170 llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) { 171 if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) { 172 // Don't allow pointers. Splat vectors are fine. 173 if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy()) 174 return std::nullopt; 175 return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), 176 ICmp->getPredicate(), LookThruTrunc, 177 AllowNonZeroC); 178 } 179 180 return std::nullopt; 181 } 182