1 //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/ADT/APInt.h" 10 #include "llvm/Support/DivisionByConstantInfo.h" 11 #include "gtest/gtest.h" 12 #include <array> 13 #include <optional> 14 15 using namespace llvm; 16 17 namespace { 18 19 template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) { 20 APInt N(Bits, 0); 21 do { 22 TestFn(N); 23 } while (++N != 0); 24 } 25 26 APInt MULHS(APInt X, APInt Y) { 27 unsigned Bits = X.getBitWidth(); 28 unsigned WideBits = 2 * Bits; 29 return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits); 30 } 31 32 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor, 33 SignedDivisionByConstantInfo Magics) { 34 unsigned Bits = Numerator.getBitWidth(); 35 36 APInt Factor(Bits, 0); 37 APInt ShiftMask(Bits, -1); 38 if (Divisor.isOne() || Divisor.isAllOnes()) { 39 // If d is +1/-1, we just multiply the numerator by +1/-1. 40 Factor = Divisor.getSExtValue(); 41 Magics.Magic = 0; 42 Magics.ShiftAmount = 0; 43 ShiftMask = 0; 44 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) { 45 // If d > 0 and m < 0, add the numerator. 46 Factor = 1; 47 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) { 48 // If d < 0 and m > 0, subtract the numerator. 49 Factor = -1; 50 } 51 52 // Multiply the numerator by the magic value. 53 APInt Q = MULHS(Numerator, Magics.Magic); 54 55 // (Optionally) Add/subtract the numerator using Factor. 56 Factor = Numerator * Factor; 57 Q = Q + Factor; 58 59 // Shift right algebraic by shift value. 60 Q = Q.ashr(Magics.ShiftAmount); 61 62 // Extract the sign bit, mask it and add it to the quotient. 63 unsigned SignShift = Bits - 1; 64 APInt T = Q.lshr(SignShift); 65 T = T & ShiftMask; 66 return Q + T; 67 } 68 69 TEST(SignedDivisionByConstantTest, Test) { 70 for (unsigned Bits = 1; Bits <= 32; ++Bits) { 71 if (Bits < 3) 72 continue; // Not supported by `SignedDivisionByConstantInfo::get()`. 73 if (Bits > 12) 74 continue; // Unreasonably slow. 75 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 76 if (Divisor.isZero()) 77 return; // Division by zero is undefined behavior. 78 SignedDivisionByConstantInfo Magics; 79 if (!(Divisor.isOne() || Divisor.isAllOnes())) 80 Magics = SignedDivisionByConstantInfo::get(Divisor); 81 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 82 if (Numerator.isMinSignedValue() && Divisor.isAllOnes()) 83 return; // Overflow is undefined behavior. 84 APInt NativeResult = Numerator.sdiv(Divisor); 85 APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics); 86 ASSERT_EQ(MagicResult, NativeResult) 87 << " ... given the operation: srem i" << Bits << " " << Numerator 88 << ", " << Divisor; 89 }); 90 }); 91 } 92 } 93 94 APInt MULHU(APInt X, APInt Y) { 95 unsigned Bits = X.getBitWidth(); 96 unsigned WideBits = 2 * Bits; 97 return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits); 98 } 99 100 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor, 101 bool LZOptimization, 102 bool AllowEvenDivisorOptimization, bool ForceNPQ, 103 UnsignedDivisionByConstantInfo Magics) { 104 assert(!Divisor.isOne() && "Division by 1 is not supported using Magic."); 105 106 unsigned Bits = Numerator.getBitWidth(); 107 108 if (LZOptimization) { 109 unsigned LeadingZeros = Numerator.countLeadingZeros(); 110 // Clip to the number of leading zeros in the divisor. 111 LeadingZeros = std::min(LeadingZeros, Divisor.countLeadingZeros()); 112 if (LeadingZeros > 0) { 113 Magics = UnsignedDivisionByConstantInfo::get( 114 Divisor, LeadingZeros, AllowEvenDivisorOptimization); 115 assert(!Magics.IsAdd && "Should use cheap fixup now"); 116 } 117 } 118 119 unsigned PreShift = 0; 120 unsigned PostShift = 0; 121 bool UseNPQ = false; 122 if (!Magics.IsAdd) { 123 assert(Magics.ShiftAmount < Divisor.getBitWidth() && 124 "We shouldn't generate an undefined shift!"); 125 PreShift = Magics.PreShift; 126 PostShift = Magics.ShiftAmount; 127 UseNPQ = false; 128 } else { 129 assert(Magics.PreShift == 0 && "Unexpected pre-shift"); 130 PostShift = Magics.ShiftAmount - 1; 131 assert(PostShift < Divisor.getBitWidth() && 132 "We shouldn't generate an undefined shift!"); 133 UseNPQ = true; 134 } 135 136 APInt NPQFactor = 137 UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits); 138 139 APInt Q = Numerator.lshr(PreShift); 140 141 // Multiply the numerator by the magic value. 142 Q = MULHU(Q, Magics.Magic); 143 144 if (UseNPQ || ForceNPQ) { 145 APInt NPQ = Numerator - Q; 146 147 // For vectors we might have a mix of non-NPQ/NPQ paths, so use 148 // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero. 149 APInt NPQ_Scalar = NPQ.lshr(1); 150 (void)NPQ_Scalar; 151 NPQ = MULHU(NPQ, NPQFactor); 152 assert(!UseNPQ || NPQ == NPQ_Scalar); 153 154 Q = NPQ + Q; 155 } 156 157 Q = Q.lshr(PostShift); 158 159 return Q; 160 } 161 162 TEST(UnsignedDivisionByConstantTest, Test) { 163 for (unsigned Bits = 1; Bits <= 32; ++Bits) { 164 if (Bits < 2) 165 continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`. 166 if (Bits > 10) 167 continue; // Unreasonably slow. 168 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) { 169 if (Divisor.isZero()) 170 return; // Division by zero is undefined behavior. 171 if (Divisor.isOne()) 172 return; // Division by one is the numerator. 173 174 const UnsignedDivisionByConstantInfo Magics = 175 UnsignedDivisionByConstantInfo::get(Divisor); 176 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) { 177 APInt NativeResult = Numerator.udiv(Divisor); 178 for (bool LZOptimization : {true, false}) { 179 for (bool AllowEvenDivisorOptimization : {true, false}) { 180 for (bool ForceNPQ : {false, true}) { 181 APInt MagicResult = UnsignedDivideUsingMagic( 182 Numerator, Divisor, LZOptimization, 183 AllowEvenDivisorOptimization, ForceNPQ, Magics); 184 ASSERT_EQ(MagicResult, NativeResult) 185 << " ... given the operation: urem i" << Bits << " " 186 << Numerator << ", " << Divisor 187 << " (allow LZ optimization = " 188 << LZOptimization << ", allow even divisior optimization = " 189 << AllowEvenDivisorOptimization << ", force NPQ = " 190 << ForceNPQ << ")"; 191 } 192 } 193 } 194 }); 195 }); 196 } 197 } 198 199 } // end anonymous namespace 200