1 //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/ADT/APInt.h" 10 #include "llvm/Support/DivisionByConstantInfo.h" 11 #include "gtest/gtest.h" 12 #include <array> 13 14 using namespace llvm; 15 16 namespace { 17 18 template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) { 19 APInt N(Bits, 0); 20 do { 21 TestFn(N); 22 } while (++N != 0); 23 } 24 25 APInt MULHS(APInt X, APInt Y) { 26 unsigned Bits = X.getBitWidth(); 27 unsigned WideBits = 2 * Bits; 28 return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits); 29 } 30 31 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor, 32 SignedDivisionByConstantInfo Magics) { 33 unsigned Bits = Numerator.getBitWidth(); 34 35 APInt Factor(Bits, 0); 36 APInt ShiftMask(Bits, -1); 37 if (Divisor.isOne() || Divisor.isAllOnes()) { 38 // If d is +1/-1, we just multiply the numerator by +1/-1. 39 Factor = Divisor.getSExtValue(); 40 Magics.Magic = 0; 41 Magics.ShiftAmount = 0; 42 ShiftMask = 0; 43 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) { 44 // If d > 0 and m < 0, add the numerator. 45 Factor = 1; 46 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) { 47 // If d < 0 and m > 0, subtract the numerator. 48 Factor = -1; 49 } 50 51 // Multiply the numerator by the magic value. 52 APInt Q = MULHS(Numerator, Magics.Magic); 53 54 // (Optionally) Add/subtract the numerator using Factor. 55 Factor = Numerator * Factor; 56 Q = Q + Factor; 57 58 // Shift right algebraic by shift value. 59 Q = Q.ashr(Magics.ShiftAmount); 60 61 // Extract the sign bit, mask it and add it to the quotient. 62 unsigned SignShift = Bits - 1; 63 APInt T = Q.lshr(SignShift); 64 T = T & ShiftMask; 65 return Q + T; 66 } 67 68 TEST(SignedDivisionByConstantTest, Test) { 69 for (unsigned Bits = 1; Bits <= 32; ++Bits) { 70 if (Bits < 3) 71 continue; // Not supported by `SignedDivisionByConstantInfo::get()`. 72 if (Bits > 12) 73 continue; // Unreasonably slow. 74 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 75 if (Divisor.isZero()) 76 return; // Division by zero is undefined behavior. 77 SignedDivisionByConstantInfo Magics; 78 if (!(Divisor.isOne() || Divisor.isAllOnes())) 79 Magics = SignedDivisionByConstantInfo::get(Divisor); 80 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 81 if (Numerator.isMinSignedValue() && Divisor.isAllOnes()) 82 return; // Overflow is undefined behavior. 83 APInt NativeResult = Numerator.sdiv(Divisor); 84 APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics); 85 ASSERT_EQ(MagicResult, NativeResult) 86 << " ... given the operation: srem i" << Bits << " " << Numerator 87 << ", " << Divisor; 88 }); 89 }); 90 } 91 } 92 93 APInt MULHU(APInt X, APInt Y) { 94 unsigned Bits = X.getBitWidth(); 95 unsigned WideBits = 2 * Bits; 96 return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits); 97 } 98 99 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor, 100 bool LZOptimization, 101 bool AllowEvenDivisorOptimization, bool ForceNPQ, 102 UnsignedDivisionByConstantInfo Magics) { 103 assert(!Divisor.isOne() && "Division by 1 is not supported using Magic."); 104 105 unsigned Bits = Numerator.getBitWidth(); 106 107 if (LZOptimization) { 108 unsigned LeadingZeros = Numerator.countl_zero(); 109 // Clip to the number of leading zeros in the divisor. 110 LeadingZeros = std::min(LeadingZeros, Divisor.countl_zero()); 111 if (LeadingZeros > 0) { 112 Magics = UnsignedDivisionByConstantInfo::get( 113 Divisor, LeadingZeros, AllowEvenDivisorOptimization); 114 assert(!Magics.IsAdd && "Should use cheap fixup now"); 115 } 116 } 117 118 assert(Magics.PreShift < Divisor.getBitWidth() && 119 "We shouldn't generate an undefined shift!"); 120 assert(Magics.PostShift < Divisor.getBitWidth() && 121 "We shouldn't generate an undefined shift!"); 122 assert((!Magics.IsAdd || Magics.PreShift == 0) && "Unexpected pre-shift"); 123 unsigned PreShift = Magics.PreShift; 124 unsigned PostShift = Magics.PostShift; 125 bool UseNPQ = Magics.IsAdd; 126 127 APInt NPQFactor = 128 UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits); 129 130 APInt Q = Numerator.lshr(PreShift); 131 132 // Multiply the numerator by the magic value. 133 Q = MULHU(Q, Magics.Magic); 134 135 if (UseNPQ || ForceNPQ) { 136 APInt NPQ = Numerator - Q; 137 138 // For vectors we might have a mix of non-NPQ/NPQ paths, so use 139 // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero. 140 APInt NPQ_Scalar = NPQ.lshr(1); 141 (void)NPQ_Scalar; 142 NPQ = MULHU(NPQ, NPQFactor); 143 assert(!UseNPQ || NPQ == NPQ_Scalar); 144 145 Q = NPQ + Q; 146 } 147 148 Q = Q.lshr(PostShift); 149 150 return Q; 151 } 152 153 TEST(UnsignedDivisionByConstantTest, Test) { 154 for (unsigned Bits = 1; Bits <= 32; ++Bits) { 155 if (Bits < 2) 156 continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`. 157 if (Bits > 10) 158 continue; // Unreasonably slow. 159 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 160 if (Divisor.isZero()) 161 return; // Division by zero is undefined behavior. 162 if (Divisor.isOne()) 163 return; // Division by one is the numerator. 164 165 const UnsignedDivisionByConstantInfo Magics = 166 UnsignedDivisionByConstantInfo::get(Divisor); 167 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 168 APInt NativeResult = Numerator.udiv(Divisor); 169 for (bool LZOptimization : {true, false}) { 170 for (bool AllowEvenDivisorOptimization : {true, false}) { 171 for (bool ForceNPQ : {false, true}) { 172 APInt MagicResult = UnsignedDivideUsingMagic( 173 Numerator, Divisor, LZOptimization, 174 AllowEvenDivisorOptimization, ForceNPQ, Magics); 175 ASSERT_EQ(MagicResult, NativeResult) 176 << " ... given the operation: urem i" << Bits << " " 177 << Numerator << ", " << Divisor 178 << " (allow LZ optimization = " 179 << LZOptimization << ", allow even divisior optimization = " 180 << AllowEvenDivisorOptimization << ", force NPQ = " 181 << ForceNPQ << ")"; 182 } 183 } 184 } 185 }); 186 }); 187 } 188 } 189 190 } // end anonymous namespace 191