1 //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/ADT/APInt.h" 10 #include "llvm/Support/DivisionByConstantInfo.h" 11 #include "gtest/gtest.h" 12 13 using namespace llvm; 14 15 namespace { 16 17 template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) { 18 APInt N(Bits, 0); 19 do { 20 TestFn(N); 21 } while (++N != 0); 22 } 23 24 APInt MULHS(APInt X, APInt Y) { 25 unsigned Bits = X.getBitWidth(); 26 unsigned WideBits = 2 * Bits; 27 return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits); 28 } 29 30 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor, 31 SignedDivisionByConstantInfo Magics) { 32 unsigned Bits = Numerator.getBitWidth(); 33 34 APInt Factor(Bits, 0); 35 APInt ShiftMask(Bits, -1, true); 36 if (Divisor.isOne() || Divisor.isAllOnes()) { 37 // If d is +1/-1, we just multiply the numerator by +1/-1. 38 Factor = Divisor.getSExtValue(); 39 Magics.Magic = 0; 40 Magics.ShiftAmount = 0; 41 ShiftMask = 0; 42 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) { 43 // If d > 0 and m < 0, add the numerator. 44 Factor = 1; 45 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) { 46 // If d < 0 and m > 0, subtract the numerator. 47 Factor = -1; 48 } 49 50 // Multiply the numerator by the magic value. 51 APInt Q = MULHS(Numerator, Magics.Magic); 52 53 // (Optionally) Add/subtract the numerator using Factor. 54 Factor = Numerator * Factor; 55 Q = Q + Factor; 56 57 // Shift right algebraic by shift value. 58 Q = Q.ashr(Magics.ShiftAmount); 59 60 // Extract the sign bit, mask it and add it to the quotient. 61 unsigned SignShift = Bits - 1; 62 APInt T = Q.lshr(SignShift); 63 T = T & ShiftMask; 64 return Q + T; 65 } 66 67 TEST(SignedDivisionByConstantTest, Test) { 68 for (unsigned Bits = 1; Bits <= 32; ++Bits) { 69 if (Bits < 3) 70 continue; // Not supported by `SignedDivisionByConstantInfo::get()`. 71 if (Bits > 12) 72 continue; // Unreasonably slow. 73 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 74 if (Divisor.isZero()) 75 return; // Division by zero is undefined behavior. 76 SignedDivisionByConstantInfo Magics; 77 if (!(Divisor.isOne() || Divisor.isAllOnes())) 78 Magics = SignedDivisionByConstantInfo::get(Divisor); 79 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 80 if (Numerator.isMinSignedValue() && Divisor.isAllOnes()) 81 return; // Overflow is undefined behavior. 82 APInt NativeResult = Numerator.sdiv(Divisor); 83 APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics); 84 ASSERT_EQ(MagicResult, NativeResult) 85 << " ... given the operation: srem i" << Bits << " " << Numerator 86 << ", " << Divisor; 87 }); 88 }); 89 } 90 } 91 92 APInt MULHU(APInt X, APInt Y) { 93 unsigned Bits = X.getBitWidth(); 94 unsigned WideBits = 2 * Bits; 95 return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits); 96 } 97 98 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor, 99 bool LZOptimization, 100 bool AllowEvenDivisorOptimization, bool ForceNPQ, 101 UnsignedDivisionByConstantInfo Magics) { 102 assert(!Divisor.isOne() && "Division by 1 is not supported using Magic."); 103 104 unsigned Bits = Numerator.getBitWidth(); 105 106 if (LZOptimization) { 107 unsigned LeadingZeros = Numerator.countl_zero(); 108 // Clip to the number of leading zeros in the divisor. 109 LeadingZeros = std::min(LeadingZeros, Divisor.countl_zero()); 110 if (LeadingZeros > 0) { 111 Magics = UnsignedDivisionByConstantInfo::get( 112 Divisor, LeadingZeros, AllowEvenDivisorOptimization); 113 assert(!Magics.IsAdd && "Should use cheap fixup now"); 114 } 115 } 116 117 assert(Magics.PreShift < Divisor.getBitWidth() && 118 "We shouldn't generate an undefined shift!"); 119 assert(Magics.PostShift < Divisor.getBitWidth() && 120 "We shouldn't generate an undefined shift!"); 121 assert((!Magics.IsAdd || Magics.PreShift == 0) && "Unexpected pre-shift"); 122 unsigned PreShift = Magics.PreShift; 123 unsigned PostShift = Magics.PostShift; 124 bool UseNPQ = Magics.IsAdd; 125 126 APInt NPQFactor = 127 UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits); 128 129 APInt Q = Numerator.lshr(PreShift); 130 131 // Multiply the numerator by the magic value. 132 Q = MULHU(Q, Magics.Magic); 133 134 if (UseNPQ || ForceNPQ) { 135 APInt NPQ = Numerator - Q; 136 137 // For vectors we might have a mix of non-NPQ/NPQ paths, so use 138 // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero. 139 APInt NPQ_Scalar = NPQ.lshr(1); 140 (void)NPQ_Scalar; 141 NPQ = MULHU(NPQ, NPQFactor); 142 assert(!UseNPQ || NPQ == NPQ_Scalar); 143 144 Q = NPQ + Q; 145 } 146 147 Q = Q.lshr(PostShift); 148 149 return Q; 150 } 151 152 TEST(UnsignedDivisionByConstantTest, Test) { 153 for (unsigned Bits = 1; Bits <= 32; ++Bits) { 154 if (Bits < 2) 155 continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`. 156 if (Bits > 10) 157 continue; // Unreasonably slow. 158 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 159 if (Divisor.isZero()) 160 return; // Division by zero is undefined behavior. 161 if (Divisor.isOne()) 162 return; // Division by one is the numerator. 163 164 const UnsignedDivisionByConstantInfo Magics = 165 UnsignedDivisionByConstantInfo::get(Divisor); 166 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 167 APInt NativeResult = Numerator.udiv(Divisor); 168 for (bool LZOptimization : {true, false}) { 169 for (bool AllowEvenDivisorOptimization : {true, false}) { 170 for (bool ForceNPQ : {false, true}) { 171 APInt MagicResult = UnsignedDivideUsingMagic( 172 Numerator, Divisor, LZOptimization, 173 AllowEvenDivisorOptimization, ForceNPQ, Magics); 174 ASSERT_EQ(MagicResult, NativeResult) 175 << " ... given the operation: urem i" << Bits << " " 176 << Numerator << ", " << Divisor 177 << " (allow LZ optimization = " 178 << LZOptimization << ", allow even divisior optimization = " 179 << AllowEvenDivisorOptimization << ", force NPQ = " 180 << ForceNPQ << ")"; 181 } 182 } 183 } 184 }); 185 }); 186 } 187 } 188 189 } // end anonymous namespace 190