xref: /llvm-project/llvm/unittests/Support/DivisionByConstantTest.cpp (revision 37e5319a12ba47c18049728804d3d1e1b10c4eb4)
1 //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ADT/APInt.h"
10 #include "llvm/Support/DivisionByConstantInfo.h"
11 #include "gtest/gtest.h"
12 
13 using namespace llvm;
14 
15 namespace {
16 
17 template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
18   APInt N(Bits, 0);
19   do {
20     TestFn(N);
21   } while (++N != 0);
22 }
23 
24 APInt MULHS(APInt X, APInt Y) {
25   unsigned Bits = X.getBitWidth();
26   unsigned WideBits = 2 * Bits;
27   return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
28 }
29 
30 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
31                              SignedDivisionByConstantInfo Magics) {
32   unsigned Bits = Numerator.getBitWidth();
33 
34   APInt Factor(Bits, 0);
35   APInt ShiftMask(Bits, -1, true);
36   if (Divisor.isOne() || Divisor.isAllOnes()) {
37     // If d is +1/-1, we just multiply the numerator by +1/-1.
38     Factor = Divisor.getSExtValue();
39     Magics.Magic = 0;
40     Magics.ShiftAmount = 0;
41     ShiftMask = 0;
42   } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
43     // If d > 0 and m < 0, add the numerator.
44     Factor = 1;
45   } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
46     // If d < 0 and m > 0, subtract the numerator.
47     Factor = -1;
48   }
49 
50   // Multiply the numerator by the magic value.
51   APInt Q = MULHS(Numerator, Magics.Magic);
52 
53   // (Optionally) Add/subtract the numerator using Factor.
54   Factor = Numerator * Factor;
55   Q = Q + Factor;
56 
57   // Shift right algebraic by shift value.
58   Q = Q.ashr(Magics.ShiftAmount);
59 
60   // Extract the sign bit, mask it and add it to the quotient.
61   unsigned SignShift = Bits - 1;
62   APInt T = Q.lshr(SignShift);
63   T = T & ShiftMask;
64   return Q + T;
65 }
66 
67 TEST(SignedDivisionByConstantTest, Test) {
68   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
69     if (Bits < 3)
70       continue; // Not supported by `SignedDivisionByConstantInfo::get()`.
71     if (Bits > 12)
72       continue; // Unreasonably slow.
73     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
74       if (Divisor.isZero())
75         return; // Division by zero is undefined behavior.
76       SignedDivisionByConstantInfo Magics;
77       if (!(Divisor.isOne() || Divisor.isAllOnes()))
78         Magics = SignedDivisionByConstantInfo::get(Divisor);
79       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
80         if (Numerator.isMinSignedValue() && Divisor.isAllOnes())
81           return; // Overflow is undefined behavior.
82         APInt NativeResult = Numerator.sdiv(Divisor);
83         APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics);
84         ASSERT_EQ(MagicResult, NativeResult)
85             << " ... given the operation:  srem i" << Bits << " " << Numerator
86             << ", " << Divisor;
87       });
88     });
89   }
90 }
91 
92 APInt MULHU(APInt X, APInt Y) {
93   unsigned Bits = X.getBitWidth();
94   unsigned WideBits = 2 * Bits;
95   return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
96 }
97 
98 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
99                                bool LZOptimization,
100                                bool AllowEvenDivisorOptimization, bool ForceNPQ,
101                                UnsignedDivisionByConstantInfo Magics) {
102   assert(!Divisor.isOne() && "Division by 1 is not supported using Magic.");
103 
104   unsigned Bits = Numerator.getBitWidth();
105 
106   if (LZOptimization) {
107     unsigned LeadingZeros = Numerator.countl_zero();
108     // Clip to the number of leading zeros in the divisor.
109     LeadingZeros = std::min(LeadingZeros, Divisor.countl_zero());
110     if (LeadingZeros > 0) {
111       Magics = UnsignedDivisionByConstantInfo::get(
112           Divisor, LeadingZeros, AllowEvenDivisorOptimization);
113       assert(!Magics.IsAdd && "Should use cheap fixup now");
114     }
115   }
116 
117   assert(Magics.PreShift < Divisor.getBitWidth() &&
118          "We shouldn't generate an undefined shift!");
119   assert(Magics.PostShift < Divisor.getBitWidth() &&
120          "We shouldn't generate an undefined shift!");
121   assert((!Magics.IsAdd || Magics.PreShift == 0) && "Unexpected pre-shift");
122   unsigned PreShift = Magics.PreShift;
123   unsigned PostShift = Magics.PostShift;
124   bool UseNPQ = Magics.IsAdd;
125 
126   APInt NPQFactor =
127       UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits);
128 
129   APInt Q = Numerator.lshr(PreShift);
130 
131   // Multiply the numerator by the magic value.
132   Q = MULHU(Q, Magics.Magic);
133 
134   if (UseNPQ || ForceNPQ) {
135     APInt NPQ = Numerator - Q;
136 
137     // For vectors we might have a mix of non-NPQ/NPQ paths, so use
138     // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
139     APInt NPQ_Scalar = NPQ.lshr(1);
140     (void)NPQ_Scalar;
141     NPQ = MULHU(NPQ, NPQFactor);
142     assert(!UseNPQ || NPQ == NPQ_Scalar);
143 
144     Q = NPQ + Q;
145   }
146 
147   Q = Q.lshr(PostShift);
148 
149   return Q;
150 }
151 
152 TEST(UnsignedDivisionByConstantTest, Test) {
153   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
154     if (Bits < 2)
155       continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`.
156     if (Bits > 10)
157       continue; // Unreasonably slow.
158     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
159       if (Divisor.isZero())
160         return; // Division by zero is undefined behavior.
161       if (Divisor.isOne())
162         return; // Division by one is the numerator.
163 
164       const UnsignedDivisionByConstantInfo Magics =
165           UnsignedDivisionByConstantInfo::get(Divisor);
166       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
167         APInt NativeResult = Numerator.udiv(Divisor);
168         for (bool LZOptimization : {true, false}) {
169           for (bool AllowEvenDivisorOptimization : {true, false}) {
170             for (bool ForceNPQ : {false, true}) {
171               APInt MagicResult = UnsignedDivideUsingMagic(
172                   Numerator, Divisor, LZOptimization,
173                   AllowEvenDivisorOptimization, ForceNPQ, Magics);
174               ASSERT_EQ(MagicResult, NativeResult)
175                     << " ... given the operation:  urem i" << Bits << " "
176                     << Numerator << ", " << Divisor
177                     << " (allow LZ optimization = "
178                     << LZOptimization << ", allow even divisior optimization = "
179                     << AllowEvenDivisorOptimization << ", force NPQ = "
180                     << ForceNPQ << ")";
181             }
182           }
183         }
184       });
185     });
186   }
187 }
188 
189 } // end anonymous namespace
190