xref: /llvm-project/llvm/unittests/Support/DivisionByConstantTest.cpp (revision 96d0a3b5643a2081e39c9302c86b40859a8752d0)
1 //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ADT/APInt.h"
10 #include "llvm/Support/DivisionByConstantInfo.h"
11 #include "gtest/gtest.h"
12 #include <array>
13 
14 using namespace llvm;
15 
16 namespace {
17 
18 template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
19   APInt N(Bits, 0);
20   do {
21     TestFn(N);
22   } while (++N != 0);
23 }
24 
25 APInt MULHS(APInt X, APInt Y) {
26   unsigned Bits = X.getBitWidth();
27   unsigned WideBits = 2 * Bits;
28   return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
29 }
30 
31 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
32                              SignedDivisionByConstantInfo Magics) {
33   unsigned Bits = Numerator.getBitWidth();
34 
35   APInt Factor(Bits, 0);
36   APInt ShiftMask(Bits, -1);
37   if (Divisor.isOne() || Divisor.isAllOnes()) {
38     // If d is +1/-1, we just multiply the numerator by +1/-1.
39     Factor = Divisor.getSExtValue();
40     Magics.Magic = 0;
41     Magics.ShiftAmount = 0;
42     ShiftMask = 0;
43   } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
44     // If d > 0 and m < 0, add the numerator.
45     Factor = 1;
46   } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
47     // If d < 0 and m > 0, subtract the numerator.
48     Factor = -1;
49   }
50 
51   // Multiply the numerator by the magic value.
52   APInt Q = MULHS(Numerator, Magics.Magic);
53 
54   // (Optionally) Add/subtract the numerator using Factor.
55   Factor = Numerator * Factor;
56   Q = Q + Factor;
57 
58   // Shift right algebraic by shift value.
59   Q = Q.ashr(Magics.ShiftAmount);
60 
61   // Extract the sign bit, mask it and add it to the quotient.
62   unsigned SignShift = Bits - 1;
63   APInt T = Q.lshr(SignShift);
64   T = T & ShiftMask;
65   return Q + T;
66 }
67 
68 TEST(SignedDivisionByConstantTest, Test) {
69   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
70     if (Bits < 3)
71       continue; // Not supported by `SignedDivisionByConstantInfo::get()`.
72     if (Bits > 12)
73       continue; // Unreasonably slow.
74     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
75       if (Divisor.isZero())
76         return; // Division by zero is undefined behavior.
77       SignedDivisionByConstantInfo Magics;
78       if (!(Divisor.isOne() || Divisor.isAllOnes()))
79         Magics = SignedDivisionByConstantInfo::get(Divisor);
80       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
81         if (Numerator.isMinSignedValue() && Divisor.isAllOnes())
82           return; // Overflow is undefined behavior.
83         APInt NativeResult = Numerator.sdiv(Divisor);
84         APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics);
85         ASSERT_EQ(MagicResult, NativeResult)
86             << " ... given the operation:  srem i" << Bits << " " << Numerator
87             << ", " << Divisor;
88       });
89     });
90   }
91 }
92 
93 APInt MULHU(APInt X, APInt Y) {
94   unsigned Bits = X.getBitWidth();
95   unsigned WideBits = 2 * Bits;
96   return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
97 }
98 
99 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
100                                bool LZOptimization,
101                                bool AllowEvenDivisorOptimization, bool ForceNPQ,
102                                UnsignedDivisionByConstantInfo Magics) {
103   assert(!Divisor.isOne() && "Division by 1 is not supported using Magic.");
104 
105   unsigned Bits = Numerator.getBitWidth();
106 
107   if (LZOptimization) {
108     unsigned LeadingZeros = Numerator.countl_zero();
109     // Clip to the number of leading zeros in the divisor.
110     LeadingZeros = std::min(LeadingZeros, Divisor.countl_zero());
111     if (LeadingZeros > 0) {
112       Magics = UnsignedDivisionByConstantInfo::get(
113           Divisor, LeadingZeros, AllowEvenDivisorOptimization);
114       assert(!Magics.IsAdd && "Should use cheap fixup now");
115     }
116   }
117 
118   assert(Magics.PreShift < Divisor.getBitWidth() &&
119          "We shouldn't generate an undefined shift!");
120   assert(Magics.PostShift < Divisor.getBitWidth() &&
121          "We shouldn't generate an undefined shift!");
122   assert((!Magics.IsAdd || Magics.PreShift == 0) && "Unexpected pre-shift");
123   unsigned PreShift = Magics.PreShift;
124   unsigned PostShift = Magics.PostShift;
125   bool UseNPQ = Magics.IsAdd;
126 
127   APInt NPQFactor =
128       UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits);
129 
130   APInt Q = Numerator.lshr(PreShift);
131 
132   // Multiply the numerator by the magic value.
133   Q = MULHU(Q, Magics.Magic);
134 
135   if (UseNPQ || ForceNPQ) {
136     APInt NPQ = Numerator - Q;
137 
138     // For vectors we might have a mix of non-NPQ/NPQ paths, so use
139     // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
140     APInt NPQ_Scalar = NPQ.lshr(1);
141     (void)NPQ_Scalar;
142     NPQ = MULHU(NPQ, NPQFactor);
143     assert(!UseNPQ || NPQ == NPQ_Scalar);
144 
145     Q = NPQ + Q;
146   }
147 
148   Q = Q.lshr(PostShift);
149 
150   return Q;
151 }
152 
153 TEST(UnsignedDivisionByConstantTest, Test) {
154   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
155     if (Bits < 2)
156       continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`.
157     if (Bits > 10)
158       continue; // Unreasonably slow.
159     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
160       if (Divisor.isZero())
161         return; // Division by zero is undefined behavior.
162       if (Divisor.isOne())
163         return; // Division by one is the numerator.
164 
165       const UnsignedDivisionByConstantInfo Magics =
166           UnsignedDivisionByConstantInfo::get(Divisor);
167       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
168         APInt NativeResult = Numerator.udiv(Divisor);
169         for (bool LZOptimization : {true, false}) {
170           for (bool AllowEvenDivisorOptimization : {true, false}) {
171             for (bool ForceNPQ : {false, true}) {
172               APInt MagicResult = UnsignedDivideUsingMagic(
173                   Numerator, Divisor, LZOptimization,
174                   AllowEvenDivisorOptimization, ForceNPQ, Magics);
175               ASSERT_EQ(MagicResult, NativeResult)
176                     << " ... given the operation:  urem i" << Bits << " "
177                     << Numerator << ", " << Divisor
178                     << " (allow LZ optimization = "
179                     << LZOptimization << ", allow even divisior optimization = "
180                     << AllowEvenDivisorOptimization << ", force NPQ = "
181                     << ForceNPQ << ")";
182             }
183           }
184         }
185       });
186     });
187   }
188 }
189 
190 } // end anonymous namespace
191