1 //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/ADT/APInt.h" 10 #include "llvm/Support/DivisionByConstantInfo.h" 11 #include "gtest/gtest.h" 12 #include <array> 13 #include <optional> 14 15 using namespace llvm; 16 17 namespace { 18 19 template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) { 20 APInt N(Bits, 0); 21 do { 22 TestFn(N); 23 } while (++N != 0); 24 } 25 26 APInt MULHS(APInt X, APInt Y) { 27 unsigned Bits = X.getBitWidth(); 28 unsigned WideBits = 2 * Bits; 29 return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits); 30 } 31 32 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor, 33 SignedDivisionByConstantInfo Magics) { 34 unsigned Bits = Numerator.getBitWidth(); 35 36 APInt Factor(Bits, 0); 37 APInt ShiftMask(Bits, -1); 38 if (Divisor.isOne() || Divisor.isAllOnes()) { 39 // If d is +1/-1, we just multiply the numerator by +1/-1. 40 Factor = Divisor.getSExtValue(); 41 Magics.Magic = 0; 42 Magics.ShiftAmount = 0; 43 ShiftMask = 0; 44 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) { 45 // If d > 0 and m < 0, add the numerator. 46 Factor = 1; 47 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) { 48 // If d < 0 and m > 0, subtract the numerator. 49 Factor = -1; 50 } 51 52 // Multiply the numerator by the magic value. 53 APInt Q = MULHS(Numerator, Magics.Magic); 54 55 // (Optionally) Add/subtract the numerator using Factor. 56 Factor = Numerator * Factor; 57 Q = Q + Factor; 58 59 // Shift right algebraic by shift value. 60 Q = Q.ashr(Magics.ShiftAmount); 61 62 // Extract the sign bit, mask it and add it to the quotient. 63 unsigned SignShift = Bits - 1; 64 APInt T = Q.lshr(SignShift); 65 T = T & ShiftMask; 66 return Q + T; 67 } 68 69 TEST(SignedDivisionByConstantTest, Test) { 70 for (unsigned Bits = 1; Bits <= 32; ++Bits) { 71 if (Bits < 3) 72 continue; // Not supported by `SignedDivisionByConstantInfo::get()`. 73 if (Bits > 12) 74 continue; // Unreasonably slow. 75 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 76 if (Divisor.isZero()) 77 return; // Division by zero is undefined behavior. 78 SignedDivisionByConstantInfo Magics; 79 if (!(Divisor.isOne() || Divisor.isAllOnes())) 80 Magics = SignedDivisionByConstantInfo::get(Divisor); 81 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 82 if (Numerator.isMinSignedValue() && Divisor.isAllOnes()) 83 return; // Overflow is undefined behavior. 84 APInt NativeResult = Numerator.sdiv(Divisor); 85 APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics); 86 ASSERT_EQ(MagicResult, NativeResult) 87 << " ... given the operation: srem i" << Bits << " " << Numerator 88 << ", " << Divisor; 89 }); 90 }); 91 } 92 } 93 94 APInt MULHU(APInt X, APInt Y) { 95 unsigned Bits = X.getBitWidth(); 96 unsigned WideBits = 2 * Bits; 97 return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits); 98 } 99 100 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor, 101 bool LZOptimization, 102 bool AllowEvenDivisorOptimization, bool ForceNPQ, 103 UnsignedDivisionByConstantInfo Magics) { 104 unsigned Bits = Numerator.getBitWidth(); 105 106 if (LZOptimization && !Divisor.isOne()) { 107 unsigned LeadingZeros = Numerator.countLeadingZeros(); 108 // Clip to the number of leading zeros in the divisor. 109 LeadingZeros = std::min(LeadingZeros, Divisor.countLeadingZeros()); 110 if (LeadingZeros > 0) { 111 Magics = UnsignedDivisionByConstantInfo::get(Divisor, LeadingZeros); 112 assert(!Magics.IsAdd && "Should use cheap fixup now"); 113 } 114 } 115 116 unsigned PreShift = 0; 117 if (AllowEvenDivisorOptimization) { 118 // If the divisor is even, we can avoid using the expensive fixup by 119 // shifting the divided value upfront. 120 if (Magics.IsAdd && !Divisor[0]) { 121 PreShift = Divisor.countTrailingZeros(); 122 // Get magic number for the shifted divisor. 123 Magics = 124 UnsignedDivisionByConstantInfo::get(Divisor.lshr(PreShift), PreShift); 125 assert(!Magics.IsAdd && "Should use cheap fixup now"); 126 } 127 } 128 129 unsigned PostShift = 0; 130 bool UseNPQ = false; 131 if (!Magics.IsAdd || Divisor.isOne()) { 132 assert(Magics.ShiftAmount < Divisor.getBitWidth() && 133 "We shouldn't generate an undefined shift!"); 134 PostShift = Magics.ShiftAmount; 135 UseNPQ = false; 136 } else { 137 PostShift = Magics.ShiftAmount - 1; 138 assert(PostShift < Divisor.getBitWidth() && 139 "We shouldn't generate an undefined shift!"); 140 UseNPQ = true; 141 } 142 143 APInt NPQFactor = 144 UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits); 145 146 APInt Q = Numerator.lshr(PreShift); 147 148 // Multiply the numerator by the magic value. 149 Q = MULHU(Q, Magics.Magic); 150 151 if (UseNPQ || ForceNPQ) { 152 APInt NPQ = Numerator - Q; 153 154 // For vectors we might have a mix of non-NPQ/NPQ paths, so use 155 // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero. 156 APInt NPQ_Scalar = NPQ.lshr(1); 157 (void)NPQ_Scalar; 158 NPQ = MULHU(NPQ, NPQFactor); 159 assert(!UseNPQ || NPQ == NPQ_Scalar); 160 161 Q = NPQ + Q; 162 } 163 164 Q = Q.lshr(PostShift); 165 166 return Divisor.isOne() ? Numerator : Q; 167 } 168 169 TEST(UnsignedDivisionByConstantTest, Test) { 170 for (unsigned Bits = 1; Bits <= 32; ++Bits) { 171 if (Bits < 2) 172 continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`. 173 if (Bits > 10) 174 continue; // Unreasonably slow. 175 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 176 if (Divisor.isZero()) 177 return; // Division by zero is undefined behavior. 178 const UnsignedDivisionByConstantInfo Magics = 179 UnsignedDivisionByConstantInfo::get(Divisor); 180 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 181 APInt NativeResult = Numerator.udiv(Divisor); 182 for (bool LZOptimization : {true, false}) { 183 for (bool AllowEvenDivisorOptimization : {true, false}) { 184 for (bool ForceNPQ : {false, true}) { 185 APInt MagicResult = UnsignedDivideUsingMagic( 186 Numerator, Divisor, LZOptimization, 187 AllowEvenDivisorOptimization, ForceNPQ, Magics); 188 ASSERT_EQ(MagicResult, NativeResult) 189 << " ... given the operation: urem i" << Bits << " " 190 << Numerator << ", " << Divisor 191 << " (allow LZ optimization = " 192 << LZOptimization << ", allow even divisior optimization = " 193 << AllowEvenDivisorOptimization << ", force NPQ = " 194 << ForceNPQ << ")"; 195 } 196 } 197 } 198 }); 199 }); 200 } 201 } 202 203 } // end anonymous namespace 204