1066b492bSRoman Lebedev //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===// 2066b492bSRoman Lebedev // 3066b492bSRoman Lebedev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4066b492bSRoman Lebedev // See https://llvm.org/LICENSE.txt for license information. 5066b492bSRoman Lebedev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6066b492bSRoman Lebedev // 7066b492bSRoman Lebedev //===----------------------------------------------------------------------===// 8066b492bSRoman Lebedev 9066b492bSRoman Lebedev #include "llvm/ADT/APInt.h" 10066b492bSRoman Lebedev #include "llvm/Support/DivisionByConstantInfo.h" 11066b492bSRoman Lebedev #include "gtest/gtest.h" 12066b492bSRoman Lebedev 13066b492bSRoman Lebedev using namespace llvm; 14066b492bSRoman Lebedev 15066b492bSRoman Lebedev namespace { 16066b492bSRoman Lebedev 17066b492bSRoman Lebedev template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) { 18066b492bSRoman Lebedev APInt N(Bits, 0); 19066b492bSRoman Lebedev do { 20066b492bSRoman Lebedev TestFn(N); 21066b492bSRoman Lebedev } while (++N != 0); 22066b492bSRoman Lebedev } 23066b492bSRoman Lebedev 24066b492bSRoman Lebedev APInt MULHS(APInt X, APInt Y) { 25066b492bSRoman Lebedev unsigned Bits = X.getBitWidth(); 26066b492bSRoman Lebedev unsigned WideBits = 2 * Bits; 27066b492bSRoman Lebedev return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits); 28066b492bSRoman Lebedev } 29066b492bSRoman Lebedev 30066b492bSRoman Lebedev APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor, 31066b492bSRoman Lebedev SignedDivisionByConstantInfo Magics) { 32066b492bSRoman Lebedev unsigned Bits = Numerator.getBitWidth(); 33066b492bSRoman Lebedev 34066b492bSRoman Lebedev APInt Factor(Bits, 0); 35*37e5319aSNikita Popov APInt ShiftMask(Bits, -1, true); 36066b492bSRoman Lebedev if (Divisor.isOne() || Divisor.isAllOnes()) { 37066b492bSRoman Lebedev // If d is +1/-1, we just multiply the numerator by +1/-1. 38066b492bSRoman Lebedev Factor = Divisor.getSExtValue(); 39066b492bSRoman Lebedev Magics.Magic = 0; 40066b492bSRoman Lebedev Magics.ShiftAmount = 0; 41066b492bSRoman Lebedev ShiftMask = 0; 42066b492bSRoman Lebedev } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) { 43066b492bSRoman Lebedev // If d > 0 and m < 0, add the numerator. 44066b492bSRoman Lebedev Factor = 1; 45066b492bSRoman Lebedev } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) { 46066b492bSRoman Lebedev // If d < 0 and m > 0, subtract the numerator. 47066b492bSRoman Lebedev Factor = -1; 48066b492bSRoman Lebedev } 49066b492bSRoman Lebedev 50066b492bSRoman Lebedev // Multiply the numerator by the magic value. 51066b492bSRoman Lebedev APInt Q = MULHS(Numerator, Magics.Magic); 52066b492bSRoman Lebedev 53066b492bSRoman Lebedev // (Optionally) Add/subtract the numerator using Factor. 54066b492bSRoman Lebedev Factor = Numerator * Factor; 55066b492bSRoman Lebedev Q = Q + Factor; 56066b492bSRoman Lebedev 57066b492bSRoman Lebedev // Shift right algebraic by shift value. 58066b492bSRoman Lebedev Q = Q.ashr(Magics.ShiftAmount); 59066b492bSRoman Lebedev 60066b492bSRoman Lebedev // Extract the sign bit, mask it and add it to the quotient. 61066b492bSRoman Lebedev unsigned SignShift = Bits - 1; 62066b492bSRoman Lebedev APInt T = Q.lshr(SignShift); 63066b492bSRoman Lebedev T = T & ShiftMask; 64066b492bSRoman Lebedev return Q + T; 65066b492bSRoman Lebedev } 66066b492bSRoman Lebedev 67066b492bSRoman Lebedev TEST(SignedDivisionByConstantTest, Test) { 68066b492bSRoman Lebedev for (unsigned Bits = 1; Bits <= 32; ++Bits) { 69066b492bSRoman Lebedev if (Bits < 3) 70066b492bSRoman Lebedev continue; // Not supported by `SignedDivisionByConstantInfo::get()`. 71066b492bSRoman Lebedev if (Bits > 12) 72066b492bSRoman Lebedev continue; // Unreasonably slow. 73066b492bSRoman Lebedev EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 74066b492bSRoman Lebedev if (Divisor.isZero()) 75066b492bSRoman Lebedev return; // Division by zero is undefined behavior. 76066b492bSRoman Lebedev SignedDivisionByConstantInfo Magics; 77066b492bSRoman Lebedev if (!(Divisor.isOne() || Divisor.isAllOnes())) 78066b492bSRoman Lebedev Magics = SignedDivisionByConstantInfo::get(Divisor); 79066b492bSRoman Lebedev EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 80066b492bSRoman Lebedev if (Numerator.isMinSignedValue() && Divisor.isAllOnes()) 81066b492bSRoman Lebedev return; // Overflow is undefined behavior. 82066b492bSRoman Lebedev APInt NativeResult = Numerator.sdiv(Divisor); 83066b492bSRoman Lebedev APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics); 84066b492bSRoman Lebedev ASSERT_EQ(MagicResult, NativeResult) 85066b492bSRoman Lebedev << " ... given the operation: srem i" << Bits << " " << Numerator 86066b492bSRoman Lebedev << ", " << Divisor; 87066b492bSRoman Lebedev }); 88066b492bSRoman Lebedev }); 89066b492bSRoman Lebedev } 90066b492bSRoman Lebedev } 91066b492bSRoman Lebedev 92066b492bSRoman Lebedev APInt MULHU(APInt X, APInt Y) { 93066b492bSRoman Lebedev unsigned Bits = X.getBitWidth(); 94066b492bSRoman Lebedev unsigned WideBits = 2 * Bits; 95066b492bSRoman Lebedev return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits); 96066b492bSRoman Lebedev } 97066b492bSRoman Lebedev 988abd7008SCraig Topper APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor, 998abd7008SCraig Topper bool LZOptimization, 100066b492bSRoman Lebedev bool AllowEvenDivisorOptimization, bool ForceNPQ, 101066b492bSRoman Lebedev UnsignedDivisionByConstantInfo Magics) { 1028bca60fbSCraig Topper assert(!Divisor.isOne() && "Division by 1 is not supported using Magic."); 1038bca60fbSCraig Topper 104066b492bSRoman Lebedev unsigned Bits = Numerator.getBitWidth(); 105066b492bSRoman Lebedev 1068bca60fbSCraig Topper if (LZOptimization) { 107f8f3db27SKazu Hirata unsigned LeadingZeros = Numerator.countl_zero(); 1088abd7008SCraig Topper // Clip to the number of leading zeros in the divisor. 109f8f3db27SKazu Hirata LeadingZeros = std::min(LeadingZeros, Divisor.countl_zero()); 1108abd7008SCraig Topper if (LeadingZeros > 0) { 11184daed7fSCraig Topper Magics = UnsignedDivisionByConstantInfo::get( 11284daed7fSCraig Topper Divisor, LeadingZeros, AllowEvenDivisorOptimization); 1138abd7008SCraig Topper assert(!Magics.IsAdd && "Should use cheap fixup now"); 1148abd7008SCraig Topper } 1158abd7008SCraig Topper } 1168abd7008SCraig Topper 1173f749a5dSCraig Topper assert(Magics.PreShift < Divisor.getBitWidth() && 118066b492bSRoman Lebedev "We shouldn't generate an undefined shift!"); 1193f749a5dSCraig Topper assert(Magics.PostShift < Divisor.getBitWidth() && 120012afbbaSRoman Lebedev "We shouldn't generate an undefined shift!"); 1213f749a5dSCraig Topper assert((!Magics.IsAdd || Magics.PreShift == 0) && "Unexpected pre-shift"); 1223f749a5dSCraig Topper unsigned PreShift = Magics.PreShift; 1233f749a5dSCraig Topper unsigned PostShift = Magics.PostShift; 1243f749a5dSCraig Topper bool UseNPQ = Magics.IsAdd; 125066b492bSRoman Lebedev 126066b492bSRoman Lebedev APInt NPQFactor = 127012afbbaSRoman Lebedev UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits); 128066b492bSRoman Lebedev 129066b492bSRoman Lebedev APInt Q = Numerator.lshr(PreShift); 130066b492bSRoman Lebedev 131066b492bSRoman Lebedev // Multiply the numerator by the magic value. 132066b492bSRoman Lebedev Q = MULHU(Q, Magics.Magic); 133066b492bSRoman Lebedev 134066b492bSRoman Lebedev if (UseNPQ || ForceNPQ) { 135066b492bSRoman Lebedev APInt NPQ = Numerator - Q; 136066b492bSRoman Lebedev 137066b492bSRoman Lebedev // For vectors we might have a mix of non-NPQ/NPQ paths, so use 138066b492bSRoman Lebedev // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero. 139066b492bSRoman Lebedev APInt NPQ_Scalar = NPQ.lshr(1); 140066b492bSRoman Lebedev (void)NPQ_Scalar; 141066b492bSRoman Lebedev NPQ = MULHU(NPQ, NPQFactor); 142066b492bSRoman Lebedev assert(!UseNPQ || NPQ == NPQ_Scalar); 143066b492bSRoman Lebedev 144066b492bSRoman Lebedev Q = NPQ + Q; 145066b492bSRoman Lebedev } 146066b492bSRoman Lebedev 147066b492bSRoman Lebedev Q = Q.lshr(PostShift); 148066b492bSRoman Lebedev 1498bca60fbSCraig Topper return Q; 150066b492bSRoman Lebedev } 151066b492bSRoman Lebedev 152066b492bSRoman Lebedev TEST(UnsignedDivisionByConstantTest, Test) { 153066b492bSRoman Lebedev for (unsigned Bits = 1; Bits <= 32; ++Bits) { 154066b492bSRoman Lebedev if (Bits < 2) 155066b492bSRoman Lebedev continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`. 1568abd7008SCraig Topper if (Bits > 10) 157066b492bSRoman Lebedev continue; // Unreasonably slow. 158066b492bSRoman Lebedev EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 159066b492bSRoman Lebedev if (Divisor.isZero()) 160066b492bSRoman Lebedev return; // Division by zero is undefined behavior. 1618bca60fbSCraig Topper if (Divisor.isOne()) 1628bca60fbSCraig Topper return; // Division by one is the numerator. 1638bca60fbSCraig Topper 164066b492bSRoman Lebedev const UnsignedDivisionByConstantInfo Magics = 165066b492bSRoman Lebedev UnsignedDivisionByConstantInfo::get(Divisor); 166066b492bSRoman Lebedev EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 167066b492bSRoman Lebedev APInt NativeResult = Numerator.udiv(Divisor); 1688abd7008SCraig Topper for (bool LZOptimization : {true, false}) { 169066b492bSRoman Lebedev for (bool AllowEvenDivisorOptimization : {true, false}) { 170066b492bSRoman Lebedev for (bool ForceNPQ : {false, true}) { 171066b492bSRoman Lebedev APInt MagicResult = UnsignedDivideUsingMagic( 1728abd7008SCraig Topper Numerator, Divisor, LZOptimization, 1738abd7008SCraig Topper AllowEvenDivisorOptimization, ForceNPQ, Magics); 174066b492bSRoman Lebedev ASSERT_EQ(MagicResult, NativeResult) 175066b492bSRoman Lebedev << " ... given the operation: urem i" << Bits << " " 176066b492bSRoman Lebedev << Numerator << ", " << Divisor 1778abd7008SCraig Topper << " (allow LZ optimization = " 1788abd7008SCraig Topper << LZOptimization << ", allow even divisior optimization = " 1798abd7008SCraig Topper << AllowEvenDivisorOptimization << ", force NPQ = " 1808abd7008SCraig Topper << ForceNPQ << ")"; 1818abd7008SCraig Topper } 182066b492bSRoman Lebedev } 183066b492bSRoman Lebedev } 184066b492bSRoman Lebedev }); 185066b492bSRoman Lebedev }); 186066b492bSRoman Lebedev } 187066b492bSRoman Lebedev } 188066b492bSRoman Lebedev 189066b492bSRoman Lebedev } // end anonymous namespace 190