xref: /llvm-project/llvm/unittests/Support/DivisionByConstantTest.cpp (revision 8bca60fb0aa5f82d32e619d75e65d0dea2472722)
1 //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ADT/APInt.h"
10 #include "llvm/Support/DivisionByConstantInfo.h"
11 #include "gtest/gtest.h"
12 #include <array>
13 #include <optional>
14 
15 using namespace llvm;
16 
17 namespace {
18 
19 template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
20   APInt N(Bits, 0);
21   do {
22     TestFn(N);
23   } while (++N != 0);
24 }
25 
26 APInt MULHS(APInt X, APInt Y) {
27   unsigned Bits = X.getBitWidth();
28   unsigned WideBits = 2 * Bits;
29   return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
30 }
31 
32 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
33                              SignedDivisionByConstantInfo Magics) {
34   unsigned Bits = Numerator.getBitWidth();
35 
36   APInt Factor(Bits, 0);
37   APInt ShiftMask(Bits, -1);
38   if (Divisor.isOne() || Divisor.isAllOnes()) {
39     // If d is +1/-1, we just multiply the numerator by +1/-1.
40     Factor = Divisor.getSExtValue();
41     Magics.Magic = 0;
42     Magics.ShiftAmount = 0;
43     ShiftMask = 0;
44   } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
45     // If d > 0 and m < 0, add the numerator.
46     Factor = 1;
47   } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
48     // If d < 0 and m > 0, subtract the numerator.
49     Factor = -1;
50   }
51 
52   // Multiply the numerator by the magic value.
53   APInt Q = MULHS(Numerator, Magics.Magic);
54 
55   // (Optionally) Add/subtract the numerator using Factor.
56   Factor = Numerator * Factor;
57   Q = Q + Factor;
58 
59   // Shift right algebraic by shift value.
60   Q = Q.ashr(Magics.ShiftAmount);
61 
62   // Extract the sign bit, mask it and add it to the quotient.
63   unsigned SignShift = Bits - 1;
64   APInt T = Q.lshr(SignShift);
65   T = T & ShiftMask;
66   return Q + T;
67 }
68 
69 TEST(SignedDivisionByConstantTest, Test) {
70   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
71     if (Bits < 3)
72       continue; // Not supported by `SignedDivisionByConstantInfo::get()`.
73     if (Bits > 12)
74       continue; // Unreasonably slow.
75     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
76       if (Divisor.isZero())
77         return; // Division by zero is undefined behavior.
78       SignedDivisionByConstantInfo Magics;
79       if (!(Divisor.isOne() || Divisor.isAllOnes()))
80         Magics = SignedDivisionByConstantInfo::get(Divisor);
81       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
82         if (Numerator.isMinSignedValue() && Divisor.isAllOnes())
83           return; // Overflow is undefined behavior.
84         APInt NativeResult = Numerator.sdiv(Divisor);
85         APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics);
86         ASSERT_EQ(MagicResult, NativeResult)
87             << " ... given the operation:  srem i" << Bits << " " << Numerator
88             << ", " << Divisor;
89       });
90     });
91   }
92 }
93 
94 APInt MULHU(APInt X, APInt Y) {
95   unsigned Bits = X.getBitWidth();
96   unsigned WideBits = 2 * Bits;
97   return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
98 }
99 
100 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
101                                bool LZOptimization,
102                                bool AllowEvenDivisorOptimization, bool ForceNPQ,
103                                UnsignedDivisionByConstantInfo Magics) {
104   assert(!Divisor.isOne() && "Division by 1 is not supported using Magic.");
105 
106   unsigned Bits = Numerator.getBitWidth();
107 
108   if (LZOptimization) {
109     unsigned LeadingZeros = Numerator.countLeadingZeros();
110     // Clip to the number of leading zeros in the divisor.
111     LeadingZeros = std::min(LeadingZeros, Divisor.countLeadingZeros());
112     if (LeadingZeros > 0) {
113       Magics = UnsignedDivisionByConstantInfo::get(
114           Divisor, LeadingZeros, AllowEvenDivisorOptimization);
115       assert(!Magics.IsAdd && "Should use cheap fixup now");
116     }
117   }
118 
119   unsigned PreShift = 0;
120   unsigned PostShift = 0;
121   bool UseNPQ = false;
122   if (!Magics.IsAdd) {
123     assert(Magics.ShiftAmount < Divisor.getBitWidth() &&
124            "We shouldn't generate an undefined shift!");
125     PreShift = Magics.PreShift;
126     PostShift = Magics.ShiftAmount;
127     UseNPQ = false;
128   } else {
129     assert(Magics.PreShift == 0 && "Unexpected pre-shift");
130     PostShift = Magics.ShiftAmount - 1;
131     assert(PostShift < Divisor.getBitWidth() &&
132            "We shouldn't generate an undefined shift!");
133     UseNPQ = true;
134   }
135 
136   APInt NPQFactor =
137       UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits);
138 
139   APInt Q = Numerator.lshr(PreShift);
140 
141   // Multiply the numerator by the magic value.
142   Q = MULHU(Q, Magics.Magic);
143 
144   if (UseNPQ || ForceNPQ) {
145     APInt NPQ = Numerator - Q;
146 
147     // For vectors we might have a mix of non-NPQ/NPQ paths, so use
148     // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
149     APInt NPQ_Scalar = NPQ.lshr(1);
150     (void)NPQ_Scalar;
151     NPQ = MULHU(NPQ, NPQFactor);
152     assert(!UseNPQ || NPQ == NPQ_Scalar);
153 
154     Q = NPQ + Q;
155   }
156 
157   Q = Q.lshr(PostShift);
158 
159   return Q;
160 }
161 
162 TEST(UnsignedDivisionByConstantTest, Test) {
163   for (unsigned Bits = 1; Bits <= 32; ++Bits) {
164     if (Bits < 2)
165       continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`.
166     if (Bits > 10)
167       continue; // Unreasonably slow.
168     EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
169       if (Divisor.isZero())
170         return; // Division by zero is undefined behavior.
171       if (Divisor.isOne())
172         return; // Division by one is the numerator.
173 
174       const UnsignedDivisionByConstantInfo Magics =
175           UnsignedDivisionByConstantInfo::get(Divisor);
176       EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
177         APInt NativeResult = Numerator.udiv(Divisor);
178         for (bool LZOptimization : {true, false}) {
179           for (bool AllowEvenDivisorOptimization : {true, false}) {
180             for (bool ForceNPQ : {false, true}) {
181               APInt MagicResult = UnsignedDivideUsingMagic(
182                   Numerator, Divisor, LZOptimization,
183                   AllowEvenDivisorOptimization, ForceNPQ, Magics);
184               ASSERT_EQ(MagicResult, NativeResult)
185                     << " ... given the operation:  urem i" << Bits << " "
186                     << Numerator << ", " << Divisor
187                     << " (allow LZ optimization = "
188                     << LZOptimization << ", allow even divisior optimization = "
189                     << AllowEvenDivisorOptimization << ", force NPQ = "
190                     << ForceNPQ << ")";
191             }
192           }
193         }
194       });
195     });
196   }
197 }
198 
199 } // end anonymous namespace
200