xref: /llvm-project/llvm/unittests/Support/DivisionByConstantTest.cpp (revision 8abd70081f761738e82b37b2891b60ad034f3880)
1 //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ADT/APInt.h"
10 #include "llvm/Support/DivisionByConstantInfo.h"
11 #include "gtest/gtest.h"
12 #include <array>
13 #include <optional>
14 
15 using namespace llvm;
16 
17 namespace {
18 
19 template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
20   APInt N(Bits, 0);
21   do {
22     TestFn(N);
23   } while (++N != 0);
24 }
25 
26 APInt MULHS(APInt X, APInt Y) {
27   unsigned Bits = X.getBitWidth();
28   unsigned WideBits = 2 * Bits;
29   return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
30 }
31 
32 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
33                              SignedDivisionByConstantInfo Magics) {
34   unsigned Bits = Numerator.getBitWidth();
35 
36   APInt Factor(Bits, 0);
37   APInt ShiftMask(Bits, -1);
38   if (Divisor.isOne() || Divisor.isAllOnes()) {
39     // If d is +1/-1, we just multiply the numerator by +1/-1.
40     Factor = Divisor.getSExtValue();
41     Magics.Magic = 0;
42     Magics.ShiftAmount = 0;
43     ShiftMask = 0;
44   } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
45     // If d > 0 and m < 0, add the numerator.
46     Factor = 1;
47   } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
48     // If d < 0 and m > 0, subtract the numerator.
49     Factor = -1;
50   }
51 
52   // Multiply the numerator by the magic value.
53   APInt Q = MULHS(Numerator, Magics.Magic);
54 
55   // (Optionally) Add/subtract the numerator using Factor.
56   Factor = Numerator * Factor;
57   Q = Q + Factor;
58 
59   // Shift right algebraic by shift value.
60   Q = Q.ashr(Magics.ShiftAmount);
61 
62   // Extract the sign bit, mask it and add it to the quotient.
63   unsigned SignShift = Bits - 1;
64   APInt T = Q.lshr(SignShift);
65   T = T & ShiftMask;
66   return Q + T;
67 }
68 
69 TEST(SignedDivisionByConstantTest, Test) {
70   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
71     if (Bits < 3)
72       continue; // Not supported by `SignedDivisionByConstantInfo::get()`.
73     if (Bits > 12)
74       continue; // Unreasonably slow.
75     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
76       if (Divisor.isZero())
77         return; // Division by zero is undefined behavior.
78       SignedDivisionByConstantInfo Magics;
79       if (!(Divisor.isOne() || Divisor.isAllOnes()))
80         Magics = SignedDivisionByConstantInfo::get(Divisor);
81       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
82         if (Numerator.isMinSignedValue() && Divisor.isAllOnes())
83           return; // Overflow is undefined behavior.
84         APInt NativeResult = Numerator.sdiv(Divisor);
85         APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics);
86         ASSERT_EQ(MagicResult, NativeResult)
87             << " ... given the operation:  srem i" << Bits << " " << Numerator
88             << ", " << Divisor;
89       });
90     });
91   }
92 }
93 
94 APInt MULHU(APInt X, APInt Y) {
95   unsigned Bits = X.getBitWidth();
96   unsigned WideBits = 2 * Bits;
97   return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
98 }
99 
100 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
101                                bool LZOptimization,
102                                bool AllowEvenDivisorOptimization, bool ForceNPQ,
103                                UnsignedDivisionByConstantInfo Magics) {
104   unsigned Bits = Numerator.getBitWidth();
105 
106   if (LZOptimization && !Divisor.isOne()) {
107     unsigned LeadingZeros = Numerator.countLeadingZeros();
108     // Clip to the number of leading zeros in the divisor.
109     LeadingZeros = std::min(LeadingZeros, Divisor.countLeadingZeros());
110     if (LeadingZeros > 0) {
111       Magics = UnsignedDivisionByConstantInfo::get(Divisor, LeadingZeros);
112       assert(!Magics.IsAdd && "Should use cheap fixup now");
113     }
114   }
115 
116   unsigned PreShift = 0;
117   if (AllowEvenDivisorOptimization) {
118     // If the divisor is even, we can avoid using the expensive fixup by
119     // shifting the divided value upfront.
120     if (Magics.IsAdd && !Divisor[0]) {
121       PreShift = Divisor.countTrailingZeros();
122       // Get magic number for the shifted divisor.
123       Magics =
124           UnsignedDivisionByConstantInfo::get(Divisor.lshr(PreShift), PreShift);
125       assert(!Magics.IsAdd && "Should use cheap fixup now");
126     }
127   }
128 
129   unsigned PostShift = 0;
130   bool UseNPQ = false;
131   if (!Magics.IsAdd || Divisor.isOne()) {
132     assert(Magics.ShiftAmount < Divisor.getBitWidth() &&
133            "We shouldn't generate an undefined shift!");
134     PostShift = Magics.ShiftAmount;
135     UseNPQ = false;
136   } else {
137     PostShift = Magics.ShiftAmount - 1;
138     assert(PostShift < Divisor.getBitWidth() &&
139            "We shouldn't generate an undefined shift!");
140     UseNPQ = true;
141   }
142 
143   APInt NPQFactor =
144       UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits);
145 
146   APInt Q = Numerator.lshr(PreShift);
147 
148   // Multiply the numerator by the magic value.
149   Q = MULHU(Q, Magics.Magic);
150 
151   if (UseNPQ || ForceNPQ) {
152     APInt NPQ = Numerator - Q;
153 
154     // For vectors we might have a mix of non-NPQ/NPQ paths, so use
155     // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
156     APInt NPQ_Scalar = NPQ.lshr(1);
157     (void)NPQ_Scalar;
158     NPQ = MULHU(NPQ, NPQFactor);
159     assert(!UseNPQ || NPQ == NPQ_Scalar);
160 
161     Q = NPQ + Q;
162   }
163 
164   Q = Q.lshr(PostShift);
165 
166   return Divisor.isOne() ? Numerator : Q;
167 }
168 
169 TEST(UnsignedDivisionByConstantTest, Test) {
170   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
171     if (Bits < 2)
172       continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`.
173     if (Bits > 10)
174       continue; // Unreasonably slow.
175     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
176       if (Divisor.isZero())
177         return; // Division by zero is undefined behavior.
178       const UnsignedDivisionByConstantInfo Magics =
179           UnsignedDivisionByConstantInfo::get(Divisor);
180       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
181         APInt NativeResult = Numerator.udiv(Divisor);
182         for (bool LZOptimization : {true, false}) {
183           for (bool AllowEvenDivisorOptimization : {true, false}) {
184             for (bool ForceNPQ : {false, true}) {
185               APInt MagicResult = UnsignedDivideUsingMagic(
186                   Numerator, Divisor, LZOptimization,
187                   AllowEvenDivisorOptimization, ForceNPQ, Magics);
188               ASSERT_EQ(MagicResult, NativeResult)
189                     << " ... given the operation:  urem i" << Bits << " "
190                     << Numerator << ", " << Divisor
191                     << " (allow LZ optimization = "
192                     << LZOptimization << ", allow even divisior optimization = "
193                     << AllowEvenDivisorOptimization << ", force NPQ = "
194                     << ForceNPQ << ")";
195             }
196           }
197         }
198       });
199     });
200   }
201 }
202 
203 } // end anonymous namespace
204