//===- InstCombineAddSub.cpp ------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the visit functions for add, fadd, sub, and fsub.
//
//===----------------------------------------------------------------------===//

#include "InstCombineInternal.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/AlignOf.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/KnownBits.h"
#include <cassert>
#include <utility>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "instcombine"

namespace {

  /// Class representing the coefficient of a floating-point addend.
  /// This class needs to be highly efficient; this is especially true for
  /// the constructor. As of this writing, the cost of the default
  /// constructor is merely a 4-byte zero store (assuming the compiler is
  /// able to perform write-merging).
  ///
  class FAddendCoef {
  public:
    // The constructor has to initialize an APFloat, which is unnecessary for
    // most addends, whose coefficient is either 1 or -1. So, the constructor
    // is expensive. In order to avoid the cost of the constructor, we should
    // reuse some instances whenever possible. The pre-created instances
    // FAddCombine::Add[0-5] embody this idea.
    FAddendCoef() = default;
    ~FAddendCoef();

    // If possible, don't define operator+/operator- etc., because these
    // operators inevitably call FAddendCoef's constructor, which is not cheap.
    void operator=(const FAddendCoef &A);
    void operator+=(const FAddendCoef &A);
    void operator*=(const FAddendCoef &S);

    void set(short C) {
      assert(!insaneIntVal(C) && "Insane coefficient");
      IsFp = false; IntVal = C;
    }

    void set(const APFloat& C);

    void negate();

    bool isZero() const { return isInt() ? !IntVal : getFpVal().isZero(); }
    Value *getValue(Type *) const;

    bool isOne() const { return isInt() && IntVal == 1; }
    bool isTwo() const { return isInt() && IntVal == 2; }
    bool isMinusOne() const { return isInt() && IntVal == -1; }
    bool isMinusTwo() const { return isInt() && IntVal == -2; }

  private:
    bool insaneIntVal(int V) { return V > 4 || V < -4; }

    APFloat *getFpValPtr()
      { return reinterpret_cast<APFloat *>(&FpValBuf.buffer[0]); }

    const APFloat *getFpValPtr() const
      { return reinterpret_cast<const APFloat *>(&FpValBuf.buffer[0]); }

    const APFloat &getFpVal() const {
      assert(IsFp && BufHasFpVal && "Incorrect state");
      return *getFpValPtr();
    }

    APFloat &getFpVal() {
      assert(IsFp && BufHasFpVal && "Incorrect state");
      return *getFpValPtr();
    }

    bool isInt() const { return !IsFp; }

    // If the coefficient is represented by an integer, promote it to a
    // floating-point value.
    void convertToFpType(const fltSemantics &Sem);

    // Construct an APFloat from a signed integer.
    // TODO: We should get rid of this function when APFloat can be constructed
    //       from a *SIGNED* integer.
    APFloat createAPFloatFromInt(const fltSemantics &Sem, int Val);

    bool IsFp = false;

    // True iff FpValBuf contains an instance of APFloat.
    bool BufHasFpVal = false;

    // The integer coefficient of an individual addend is either 1 or -1,
    // and we try to simplify at most 4 addends drawn from at most two
    // neighboring instructions, so the range of <IntVal> falls in [-4, 4].
    // APInt would be overkill for this purpose.
    short IntVal = 0;

    AlignedCharArrayUnion<APFloat> FpValBuf;
  };

  /// FAddend is used to represent a floating-point addend. An addend is
  /// represented as <C, V>, where V is a symbolic value, and C is a
  /// constant coefficient. A constant addend is represented as <C, 0>.
  class FAddend {
  public:
    FAddend() = default;

    void operator+=(const FAddend &T) {
      assert((Val == T.Val) && "Symbolic-values disagree");
      Coeff += T.Coeff;
    }

    Value *getSymVal() const { return Val; }
    const FAddendCoef &getCoef() const { return Coeff; }

    bool isConstant() const { return Val == nullptr; }
    bool isZero() const { return Coeff.isZero(); }

    void set(short Coefficient, Value *V) {
      Coeff.set(Coefficient);
      Val = V;
    }
    void set(const APFloat &Coefficient, Value *V) {
      Coeff.set(Coefficient);
      Val = V;
    }
    void set(const ConstantFP *Coefficient, Value *V) {
      Coeff.set(Coefficient->getValueAPF());
      Val = V;
    }

    void negate() { Coeff.negate(); }

    /// Drill down the U-D chain one step to find the definition of V, and
    /// try to break the definition into one or two addends.
    static unsigned drillValueDownOneStep(Value* V, FAddend &A0, FAddend &A1);

    /// Similar to FAddend::drillDownOneStep() except that the value being
    /// split is the addend itself.
    unsigned drillAddendDownOneStep(FAddend &Addend0, FAddend &Addend1) const;

  private:
    void Scale(const FAddendCoef& ScaleAmt) { Coeff *= ScaleAmt; }

    // This addend has the value of "Coeff * Val".
    Value *Val = nullptr;
    FAddendCoef Coeff;
  };

  /// FAddCombine is the class for optimizing an unsafe fadd/fsub along
  /// with at most two of its neighboring instructions.
  ///
  class FAddCombine {
  public:
    FAddCombine(InstCombiner::BuilderTy &B) : Builder(B) {}

    Value *simplify(Instruction *FAdd);

  private:
    using AddendVect = SmallVector<const FAddend *, 4>;

    Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota);

    /// Convert the given addend to a Value.
    Value *createAddendVal(const FAddend &A, bool& NeedNeg);

    /// Return the number of instructions needed to emit the N-ary addition.
    unsigned calcInstrNumber(const AddendVect& Vect);

    Value *createFSub(Value *Opnd0, Value *Opnd1);
    Value *createFAdd(Value *Opnd0, Value *Opnd1);
    Value *createFMul(Value *Opnd0, Value *Opnd1);
    Value *createFNeg(Value *V);
    Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota);
    void createInstPostProc(Instruction *NewInst, bool NoNumber = false);

    // Debugging stuff is clustered here.
    #ifndef NDEBUG
      unsigned CreateInstrNum;
      void initCreateInstNum() { CreateInstrNum = 0; }
      void incCreateInstNum() { CreateInstrNum++; }
    #else
      void initCreateInstNum() {}
      void incCreateInstNum() {}
    #endif

    InstCombiner::BuilderTy &Builder;
    Instruction *Instr = nullptr;
  };

} // end anonymous namespace

//===----------------------------------------------------------------------===//
//
// Implementation of
//    {FAddendCoef, FAddend, FAddCombine}.
//
//===----------------------------------------------------------------------===//
FAddendCoef::~FAddendCoef() {
  if (BufHasFpVal)
    getFpValPtr()->~APFloat();
}

void FAddendCoef::set(const APFloat& C) {
  APFloat *P = getFpValPtr();

  if (isInt()) {
    // As the buffer is a meaningless byte stream, we cannot call
    // APFloat::operator=().
    new(P) APFloat(C);
  } else
    *P = C;

  IsFp = BufHasFpVal = true;
}

void FAddendCoef::convertToFpType(const fltSemantics &Sem) {
  if (!isInt())
    return;

  APFloat *P = getFpValPtr();
  if (IntVal > 0)
    new(P) APFloat(Sem, IntVal);
  else {
    new(P) APFloat(Sem, 0 - IntVal);
    P->changeSign();
  }
  IsFp = BufHasFpVal = true;
}

APFloat FAddendCoef::createAPFloatFromInt(const fltSemantics &Sem, int Val) {
  if (Val >= 0)
    return APFloat(Sem, Val);

  APFloat T(Sem, 0 - Val);
  T.changeSign();

  return T;
}

void FAddendCoef::operator=(const FAddendCoef &That) {
  if (That.isInt())
    set(That.IntVal);
  else
    set(That.getFpVal());
}

void FAddendCoef::operator+=(const FAddendCoef &That) {
  enum APFloat::roundingMode RndMode = APFloat::rmNearestTiesToEven;
  if (isInt() == That.isInt()) {
    if (isInt())
      IntVal += That.IntVal;
    else
      getFpVal().add(That.getFpVal(), RndMode);
    return;
  }

  if (isInt()) {
    const APFloat &T = That.getFpVal();
    convertToFpType(T.getSemantics());
    getFpVal().add(T, RndMode);
    return;
  }

  APFloat &T = getFpVal();
  T.add(createAPFloatFromInt(T.getSemantics(), That.IntVal), RndMode);
}

void FAddendCoef::operator*=(const FAddendCoef &That) {
  if (That.isOne())
    return;

  if (That.isMinusOne()) {
    negate();
    return;
  }

  if (isInt() && That.isInt()) {
    int Res = IntVal * (int)That.IntVal;
    assert(!insaneIntVal(Res) && "Insane int value");
    IntVal = Res;
    return;
  }

  const fltSemantics &Semantic =
    isInt() ? That.getFpVal().getSemantics() : getFpVal().getSemantics();

  if (isInt())
    convertToFpType(Semantic);
  APFloat &F0 = getFpVal();

  if (That.isInt())
    F0.multiply(createAPFloatFromInt(Semantic, That.IntVal),
                APFloat::rmNearestTiesToEven);
  else
    F0.multiply(That.getFpVal(), APFloat::rmNearestTiesToEven);
}

void FAddendCoef::negate() {
  if (isInt())
    IntVal = 0 - IntVal;
  else
    getFpVal().changeSign();
}

Value *FAddendCoef::getValue(Type *Ty) const {
  return isInt() ?
    ConstantFP::get(Ty, float(IntVal)) :
    ConstantFP::get(Ty->getContext(), getFpVal());
}

// The definition of <Val>     Addends
// =========================================
//  A + B                     <1, A>, <1, B>
//  A - B                     <1, A>, <-1, B>
//  0 - B                     <-1, B>
//  C * A                     <C, A>
//  A + C                     <1, A>, <C, NULL>
//  0 +/- 0                   <0, NULL> (corner case)
//
// Legend: A and B are not constant, C is constant
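//
// For example (a worked instance of the table above): the instruction
// "%t = fsub float %x, %y" is decomposed into the two addends <1, %x> and
// <-1, %y>, while "%t = fmul float %x, 2.0" yields the single addend
// <2.0, %x>.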
unsigned FAddend::drillValueDownOneStep
  (Value *Val, FAddend &Addend0, FAddend &Addend1) {
  Instruction *I = nullptr;
  if (!Val || !(I = dyn_cast<Instruction>(Val)))
    return 0;

  unsigned Opcode = I->getOpcode();

  if (Opcode == Instruction::FAdd || Opcode == Instruction::FSub) {
    ConstantFP *C0, *C1;
    Value *Opnd0 = I->getOperand(0);
    Value *Opnd1 = I->getOperand(1);
    if ((C0 = dyn_cast<ConstantFP>(Opnd0)) && C0->isZero())
      Opnd0 = nullptr;

    if ((C1 = dyn_cast<ConstantFP>(Opnd1)) && C1->isZero())
      Opnd1 = nullptr;

    if (Opnd0) {
      if (!C0)
        Addend0.set(1, Opnd0);
      else
        Addend0.set(C0, nullptr);
    }

    if (Opnd1) {
      FAddend &Addend = Opnd0 ? Addend1 : Addend0;
      if (!C1)
        Addend.set(1, Opnd1);
      else
        Addend.set(C1, nullptr);
      if (Opcode == Instruction::FSub)
        Addend.negate();
    }

    if (Opnd0 || Opnd1)
      return Opnd0 && Opnd1 ? 2 : 1;

    // Both operands are zero. Weird!
    Addend0.set(APFloat(C0->getValueAPF().getSemantics()), nullptr);
    return 1;
  }

  if (I->getOpcode() == Instruction::FMul) {
    Value *V0 = I->getOperand(0);
    Value *V1 = I->getOperand(1);
    if (ConstantFP *C = dyn_cast<ConstantFP>(V0)) {
      Addend0.set(C, V1);
      return 1;
    }

    if (ConstantFP *C = dyn_cast<ConstantFP>(V1)) {
      Addend0.set(C, V0);
      return 1;
    }
  }

  return 0;
}

// Try to break *this* addend into two addends. e.g., suppose this addend is
// <2.3, V> and V = X + Y; by calling this function, we obtain the two
// addends <2.3, X> and <2.3, Y>.
unsigned FAddend::drillAddendDownOneStep
  (FAddend &Addend0, FAddend &Addend1) const {
  if (isConstant())
    return 0;

  unsigned BreakNum = FAddend::drillValueDownOneStep(Val, Addend0, Addend1);
  if (!BreakNum || Coeff.isOne())
    return BreakNum;

  Addend0.Scale(Coeff);

  if (BreakNum == 2)
    Addend1.Scale(Coeff);

  return BreakNum;
}

Value *FAddCombine::simplify(Instruction *I) {
  assert(I->hasAllowReassoc() && I->hasNoSignedZeros() &&
         "Expected 'reassoc'+'nsz' instruction");

  // Currently we are not able to handle vector types.
  if (I->getType()->isVectorTy())
    return nullptr;

  assert((I->getOpcode() == Instruction::FAdd ||
          I->getOpcode() == Instruction::FSub) && "Expect add/sub");

  // Save the instruction before calling other member functions.
  Instr = I;

  FAddend Opnd0, Opnd1, Opnd0_0, Opnd0_1, Opnd1_0, Opnd1_1;

  unsigned OpndNum = FAddend::drillValueDownOneStep(I, Opnd0, Opnd1);

  // Step 1: Expand the 1st addend into Opnd0_0 and Opnd0_1.
  unsigned Opnd0_ExpNum = 0;
  unsigned Opnd1_ExpNum = 0;

  if (!Opnd0.isConstant())
    Opnd0_ExpNum = Opnd0.drillAddendDownOneStep(Opnd0_0, Opnd0_1);

  // Step 2: Expand the 2nd addend into Opnd1_0 and Opnd1_1.
  if (OpndNum == 2 && !Opnd1.isConstant())
    Opnd1_ExpNum = Opnd1.drillAddendDownOneStep(Opnd1_0, Opnd1_1);

  // Step 3: Try to optimize Opnd0_0 + Opnd0_1 + Opnd1_0 + Opnd1_1.
  if (Opnd0_ExpNum && Opnd1_ExpNum) {
    AddendVect AllOpnds;
    AllOpnds.push_back(&Opnd0_0);
    AllOpnds.push_back(&Opnd1_0);
    if (Opnd0_ExpNum == 2)
      AllOpnds.push_back(&Opnd0_1);
    if (Opnd1_ExpNum == 2)
      AllOpnds.push_back(&Opnd1_1);

    // Compute the instruction quota. We should save at least one instruction.
    unsigned InstQuota = 0;

    Value *V0 = I->getOperand(0);
    Value *V1 = I->getOperand(1);
    InstQuota = ((!isa<Constant>(V0) && V0->hasOneUse()) &&
                 (!isa<Constant>(V1) && V1->hasOneUse())) ? 2 : 1;

    if (Value *R = simplifyFAdd(AllOpnds, InstQuota))
      return R;
  }

  if (OpndNum != 2) {
    // The input instruction is: "I = 0.0 +/- V". If "V" could be split into
    // two addends, say "V = X - Y", the instruction would have been
    // optimized into "I = Y - X" in the previous steps.
    const FAddendCoef &CE = Opnd0.getCoef();
    return CE.isOne() ? Opnd0.getSymVal() : nullptr;
  }

  // Step 4: Try to optimize Opnd0 + Opnd1_0 [+ Opnd1_1].
  if (Opnd1_ExpNum) {
    AddendVect AllOpnds;
    AllOpnds.push_back(&Opnd0);
    AllOpnds.push_back(&Opnd1_0);
    if (Opnd1_ExpNum == 2)
      AllOpnds.push_back(&Opnd1_1);

    if (Value *R = simplifyFAdd(AllOpnds, 1))
      return R;
  }

  // Step 5: Try to optimize Opnd1 + Opnd0_0 [+ Opnd0_1].
  if (Opnd0_ExpNum) {
    AddendVect AllOpnds;
    AllOpnds.push_back(&Opnd1);
    AllOpnds.push_back(&Opnd0_0);
    if (Opnd0_ExpNum == 2)
      AllOpnds.push_back(&Opnd0_1);

    if (Value *R = simplifyFAdd(AllOpnds, 1))
      return R;
  }

  return nullptr;
}

Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
  unsigned AddendNum = Addends.size();
  assert(AddendNum <= 4 && "Too many addends");

  // For saving intermediate results.
  unsigned NextTmpIdx = 0;
  FAddend TmpResult[3];

  // Points to the constant addend of the resulting simplified expression.
  // If the resulting expression has a constant addend, it is desirable for
  // that constant to reside at the top of the resulting expression tree.
  // Placing constants close to the super-expression(s) will potentially
  // reveal some optimization opportunities in the super-expression(s).
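  // For example, simplifying "(x + 3.0) + (y + 5.0)" to "(x + y) + 8.0"
  // keeps the constant 8.0 at the root, where an enclosing expression such
  // as "... - 8.0" could fold it away.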
  const FAddend *ConstAdd = nullptr;

  // Simplified addends are placed in <SimpVect>.
  AddendVect SimpVect;

  // The outer loop works on one symbolic value at a time. Suppose the input
  // addends are: <a1, x>, <b1, y>, <a2, x>, <c1, z>, <b2, y>, ...
  // The symbolic values will be processed in this order: x, y, z.
  for (unsigned SymIdx = 0; SymIdx < AddendNum; SymIdx++) {

    const FAddend *ThisAddend = Addends[SymIdx];
    if (!ThisAddend) {
      // This addend was processed before.
      continue;
    }

    Value *Val = ThisAddend->getSymVal();
    unsigned StartIdx = SimpVect.size();
    SimpVect.push_back(ThisAddend);

    // The inner loop collects addends sharing the same symbolic value, and
    // these addends will later be folded into a single addend. Following the
    // above example, if the symbolic value "y" is being processed, the inner
    // loop will collect the two addends "<b1, y>" and "<b2, y>". These two
    // addends will later be folded into "<b1+b2, y>".
    for (unsigned SameSymIdx = SymIdx + 1;
         SameSymIdx < AddendNum; SameSymIdx++) {
      const FAddend *T = Addends[SameSymIdx];
      if (T && T->getSymVal() == Val) {
        // Set to null so that the next iteration of the outer loop will not
        // process this addend again.
        Addends[SameSymIdx] = nullptr;
        SimpVect.push_back(T);
      }
    }

    // If multiple addends share the same symbolic value, fold them together.
    if (StartIdx + 1 != SimpVect.size()) {
      FAddend &R = TmpResult[NextTmpIdx++];
      R = *SimpVect[StartIdx];
      for (unsigned Idx = StartIdx + 1; Idx < SimpVect.size(); Idx++)
        R += *SimpVect[Idx];

      // Pop all addends being folded and push the resulting folded addend.
      SimpVect.resize(StartIdx);
      if (Val) {
        if (!R.isZero()) {
          SimpVect.push_back(&R);
        }
      } else {
        // Don't push the constant addend at this time. It will be the last
        // element of <SimpVect>.
        ConstAdd = &R;
      }
    }
  }

  assert((NextTmpIdx <= array_lengthof(TmpResult) + 1) &&
         "out-of-bound access");

  if (ConstAdd)
    SimpVect.push_back(ConstAdd);

  Value *Result;
  if (!SimpVect.empty())
    Result = createNaryFAdd(SimpVect, InstrQuota);
  else {
    // The addition is folded to 0.0.
    Result = ConstantFP::get(Instr->getType(), 0.0);
  }

  return Result;
}

Value *FAddCombine::createNaryFAdd
  (const AddendVect &Opnds, unsigned InstrQuota) {
  assert(!Opnds.empty() && "Expect at least one addend");

  // Step 1: Check if the # of instructions needed exceeds the quota.

  unsigned InstrNeeded = calcInstrNumber(Opnds);
  if (InstrNeeded > InstrQuota)
    return nullptr;

  initCreateInstNum();

  // Step 2: Emit the N-ary addition.
  // Note that at most three instructions are involved in Fadd-InstCombine:
  // the addition in question, and at most two neighboring instructions.
  // The resulting optimized addition should have at least one less
  // instruction than the original addition expression tree. This implies
  // that the resulting N-ary addition has at most two instructions, and we
  // don't need to worry about tree-height when constructing the N-ary
  // addition.

  Value *LastVal = nullptr;
  bool LastValNeedNeg = false;

  // Iterate over the addends, creating fadd/fsub from adjacent pairs of
  // addends.
  for (const FAddend *Opnd : Opnds) {
    bool NeedNeg;
    Value *V = createAddendVal(*Opnd, NeedNeg);
    if (!LastVal) {
      LastVal = V;
      LastValNeedNeg = NeedNeg;
      continue;
    }

    if (LastValNeedNeg == NeedNeg) {
      LastVal = createFAdd(LastVal, V);
      continue;
    }

    if (LastValNeedNeg)
      LastVal = createFSub(V, LastVal);
    else
      LastVal = createFSub(LastVal, V);

    LastValNeedNeg = false;
  }

  if (LastValNeedNeg) {
    LastVal = createFNeg(LastVal);
  }

#ifndef NDEBUG
  assert(CreateInstrNum == InstrNeeded &&
         "Inconsistent instruction numbers");
#endif

  return LastVal;
}

Value *FAddCombine::createFSub(Value *Opnd0, Value *Opnd1) {
  Value *V = Builder.CreateFSub(Opnd0, Opnd1);
  if (Instruction *I = dyn_cast<Instruction>(V))
    createInstPostProc(I);
  return V;
}

Value *FAddCombine::createFNeg(Value *V) {
  Value *Zero = cast<Value>(ConstantFP::getZeroValueForNegation(V->getType()));
  Value *NewV = createFSub(Zero, V);
  if (Instruction *I = dyn_cast<Instruction>(NewV))
    createInstPostProc(I, true); // fnegs don't receive instruction numbers.
  return NewV;
}

Value *FAddCombine::createFAdd(Value *Opnd0, Value *Opnd1) {
  Value *V = Builder.CreateFAdd(Opnd0, Opnd1);
  if (Instruction *I = dyn_cast<Instruction>(V))
    createInstPostProc(I);
  return V;
}

Value *FAddCombine::createFMul(Value *Opnd0, Value *Opnd1) {
  Value *V = Builder.CreateFMul(Opnd0, Opnd1);
  if (Instruction *I = dyn_cast<Instruction>(V))
    createInstPostProc(I);
  return V;
}

void FAddCombine::createInstPostProc(Instruction *NewInstr, bool NoNumber) {
  NewInstr->setDebugLoc(Instr->getDebugLoc());

  // Keep track of the number of instructions created.
  if (!NoNumber)
    incCreateInstNum();

  // Propagate fast-math flags.
  NewInstr->setFastMathFlags(Instr->getFastMathFlags());
}

// Return the number of instructions needed to emit the N-ary addition.
// NOTE: Keep this function in sync with createAddendVal().
unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) {
  unsigned OpndNum = Opnds.size();
  unsigned InstrNeeded = OpndNum - 1;

  // The number of addends in the form of "(-1)*x".
  unsigned NegOpndNum = 0;

  // Adjust the number of instructions needed to emit the N-ary add.
  for (const FAddend *Opnd : Opnds) {
    if (Opnd->isConstant())
      continue;

    // The constant check above is really for a few special constant
    // coefficients.
    if (isa<UndefValue>(Opnd->getSymVal()))
      continue;

    const FAddendCoef &CE = Opnd->getCoef();
    if (CE.isMinusOne() || CE.isMinusTwo())
      NegOpndNum++;

    // Let the addend be "c * x". If "c == +/-1", the value of the addend
    // is immediately available; otherwise, it needs exactly one instruction
    // to evaluate the value.
    if (!CE.isMinusOne() && !CE.isOne())
      InstrNeeded++;
  }
  if (NegOpndNum == OpndNum)
    InstrNeeded++;
  return InstrNeeded;
}

// Input Addend        Value           NeedNeg(output)
// ================================================================
// Constant C          C               false
// <+/-1, V>           V               coefficient is -1
// <2/-2, V>          "fadd V, V"      coefficient is -2
// <C, V>             "fmul V, C"      false
//
// NOTE: Keep this function in sync with FAddCombine::calcInstrNumber.
Value *FAddCombine::createAddendVal(const FAddend &Opnd, bool &NeedNeg) {
  const FAddendCoef &Coeff = Opnd.getCoef();

  if (Opnd.isConstant()) {
    NeedNeg = false;
    return Coeff.getValue(Instr->getType());
  }

  Value *OpndVal = Opnd.getSymVal();

  if (Coeff.isMinusOne() || Coeff.isOne()) {
    NeedNeg = Coeff.isMinusOne();
    return OpndVal;
  }

  if (Coeff.isTwo() || Coeff.isMinusTwo()) {
    NeedNeg = Coeff.isMinusTwo();
    return createFAdd(OpndVal, OpndVal);
  }

  NeedNeg = false;
  return createFMul(OpndVal, Coeff.getValue(Instr->getType()));
}

// Checks if any operand is negative and we can convert the add to a sub.
// This function checks for the following negative patterns:
//   ADD(XOR(OR(Z, NOT(C)), C), 1) == NEG(AND(Z, C))
//   ADD(XOR(AND(Z, C), C), 1) == NEG(OR(Z, ~C))
//   XOR(AND(Z, C), (C + 1)) == NEG(OR(Z, ~C)) if C is even
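//
// As a sanity check of the first identity with 8-bit values Z = 0x05 and
// C = 0x03: OR(Z, NOT(C)) = 0xFD, XOR(0xFD, C) = 0xFE, and adding 1 gives
// 0xFF, which is exactly NEG(AND(Z, C)) = NEG(0x01).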
static Value *checkForNegativeOperand(BinaryOperator &I,
                                      InstCombiner::BuilderTy &Builder) {
  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);

  // This function creates 2 instructions to replace ADD, so we need at least
  // one of LHS or RHS to have one use to ensure a benefit from the transform.
  if (!LHS->hasOneUse() && !RHS->hasOneUse())
    return nullptr;

  Value *X = nullptr, *Y = nullptr, *Z = nullptr;
  const APInt *C1 = nullptr, *C2 = nullptr;

  // If ONE is on the other side, swap.
  if (match(RHS, m_Add(m_Value(X), m_One())))
    std::swap(LHS, RHS);

  if (match(LHS, m_Add(m_Value(X), m_One()))) {
    // If XOR is on the other side, swap.
    if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1))))
      std::swap(X, RHS);

    if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) {
      // X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1))
      // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1))
      if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) {
        Value *NewAnd = Builder.CreateAnd(Z, *C1);
        return Builder.CreateSub(RHS, NewAnd, "sub");
      } else if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) {
        // X = XOR(Y, C1), Y = AND(Z, C2), C2 == C1 ==> X == NOT(OR(Z, ~C1))
        // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, OR(Z, ~C1))
        Value *NewOr = Builder.CreateOr(Z, ~(*C1));
        return Builder.CreateSub(RHS, NewOr, "sub");
      }
    }
  }

  // Restore LHS and RHS.
  LHS = I.getOperand(0);
  RHS = I.getOperand(1);

  // If XOR is on the other side, swap.
  if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1))))
    std::swap(LHS, RHS);

  // C1 is ODD (and hence C2 == C1 - 1 is even):
  // LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2))
  // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
  if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
    if (C1->countTrailingZeros() == 0)
      if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
        Value *NewOr = Builder.CreateOr(Z, ~(*C2));
        return Builder.CreateSub(RHS, NewOr, "sub");
      }
  return nullptr;
}

/// Wrapping flags may allow combining constants separated by an extend.
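/// For instance (an illustrative i8/i16 case): with %x of type i8,
///   (zext i16 (add nuw i8 %x, 10)) + (i16 -5) --> zext (add nuw i8 %x, 5)
/// since the narrow add cannot wrap and the negative wide constant stays
/// within the zero-extended range.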
static Instruction *foldNoWrapAdd(BinaryOperator &Add,
                                  InstCombiner::BuilderTy &Builder) {
  Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1);
  Type *Ty = Add.getType();
  Constant *Op1C;
  if (!match(Op1, m_Constant(Op1C)))
    return nullptr;

  // Try this match first because it results in an add in the narrow type.
  // (zext (X +nuw C2)) + C1 --> zext (X + (C2 + trunc(C1)))
  Value *X;
  const APInt *C1, *C2;
  if (match(Op1, m_APInt(C1)) &&
      match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2))))) &&
      C1->isNegative() && C1->sge(-C2->sext(C1->getBitWidth()))) {
    Constant *NewC =
        ConstantInt::get(X->getType(), *C2 + C1->trunc(C2->getBitWidth()));
    return new ZExtInst(Builder.CreateNUWAdd(X, NewC), Ty);
  }

  // More general combining of constants in the wide type.
  // (sext (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C)
  Constant *NarrowC;
  if (match(Op0, m_OneUse(m_SExt(m_NSWAdd(m_Value(X), m_Constant(NarrowC)))))) {
    Constant *WideC = ConstantExpr::getSExt(NarrowC, Ty);
    Constant *NewC = ConstantExpr::getAdd(WideC, Op1C);
    Value *WideX = Builder.CreateSExt(X, Ty);
    return BinaryOperator::CreateAdd(WideX, NewC);
  }
  // (zext (X +nuw NarrowC)) + C --> (zext X) + (zext(NarrowC) + C)
  if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_Constant(NarrowC)))))) {
    Constant *WideC = ConstantExpr::getZExt(NarrowC, Ty);
    Constant *NewC = ConstantExpr::getAdd(WideC, Op1C);
    Value *WideX = Builder.CreateZExt(X, Ty);
    return BinaryOperator::CreateAdd(WideX, NewC);
  }

  return nullptr;
}

Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) {
  Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1);
  Constant *Op1C;
  if (!match(Op1, m_Constant(Op1C)))
    return nullptr;

  if (Instruction *NV = foldBinOpIntoSelectOrPhi(Add))
    return NV;

  Value *X;
  Constant *Op00C;

  // add (sub C1, X), C2 --> sub (add C1, C2), X
  if (match(Op0, m_Sub(m_Constant(Op00C), m_Value(X))))
    return BinaryOperator::CreateSub(ConstantExpr::getAdd(Op00C, Op1C), X);

  Value *Y;

  // add (sub X, Y), -1 --> add (not Y), X
  if (match(Op0, m_OneUse(m_Sub(m_Value(X), m_Value(Y)))) &&
      match(Op1, m_AllOnes()))
    return BinaryOperator::CreateAdd(Builder.CreateNot(Y), X);

  // zext(bool) + C -> bool ? C + 1 : C
  if (match(Op0, m_ZExt(m_Value(X))) &&
      X->getType()->getScalarSizeInBits() == 1)
    return SelectInst::Create(X, AddOne(Op1C), Op1);
  // sext(bool) + C -> bool ? C - 1 : C
  if (match(Op0, m_SExt(m_Value(X))) &&
      X->getType()->getScalarSizeInBits() == 1)
    return SelectInst::Create(X, SubOne(Op1C), Op1);

  // ~X + C --> (C-1) - X
  if (match(Op0, m_Not(m_Value(X))))
    return BinaryOperator::CreateSub(SubOne(Op1C), X);

  const APInt *C;
  if (!match(Op1, m_APInt(C)))
    return nullptr;

  // (X | C2) + C --> (X | C2) ^ C2 iff (C2 == -C)
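  // (The 'or' guarantees that the C2 bits are set, so adding C == -C2 simply
  // clears exactly those bits, which is what the xor does.)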
  const APInt *C2;
  if (match(Op0, m_Or(m_Value(), m_APInt(C2))) && *C2 == -*C)
    return BinaryOperator::CreateXor(Op0, ConstantInt::get(Add.getType(), *C2));

  if (C->isSignMask()) {
    // If wrapping is not allowed, then the addition must set the sign bit:
    // X + (signmask) --> X | signmask
    if (Add.hasNoSignedWrap() || Add.hasNoUnsignedWrap())
      return BinaryOperator::CreateOr(Op0, Op1);

    // If wrapping is allowed, then the addition flips the sign bit of LHS:
    // X + (signmask) --> X ^ signmask
    return BinaryOperator::CreateXor(Op0, Op1);
  }

  // Is this add the last step in a convoluted sext?
  // add(zext(xor i16 X, -32768), -32768) --> sext X
  Type *Ty = Add.getType();
  if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) &&
      C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C)
    return CastInst::Create(Instruction::SExt, X, Ty);

  if (C->isOneValue() && Op0->hasOneUse()) {
    // add (sext i1 X), 1 --> zext (not X)
    // TODO: The smallest IR representation is (select X, 0, 1), and that would
    // not require the one-use check. But we need to remove a transform in
    // visitSelect and make sure that IR value tracking for select is equal or
    // better than for these ops.
    if (match(Op0, m_SExt(m_Value(X))) &&
        X->getType()->getScalarSizeInBits() == 1)
      return new ZExtInst(Builder.CreateNot(X), Ty);

    // Shifts and add used to flip and mask off the low bit:
    // add (ashr (shl i32 X, 31), 31), 1 --> and (not X), 1
    const APInt *C3;
    if (match(Op0, m_AShr(m_Shl(m_Value(X), m_APInt(C2)), m_APInt(C3))) &&
        C2 == C3 && *C2 == Ty->getScalarSizeInBits() - 1) {
      Value *NotX = Builder.CreateNot(X);
      return BinaryOperator::CreateAnd(NotX, ConstantInt::get(Ty, 1));
    }
  }

  return nullptr;
}

// Matches a multiplication expression Op * C, where C is a constant. Returns
// the constant value in C and the other operand in Op. Returns true if such
// a match is found.
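// For example, "mul i32 %x, 5" yields Op = %x and C = 5, and the shift form
// "shl i32 %x, 3" is recognized as the multiplication %x * 8.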
static bool MatchMul(Value *E, Value *&Op, APInt &C) {
  const APInt *AI;
  if (match(E, m_Mul(m_Value(Op), m_APInt(AI)))) {
    C = *AI;
    return true;
  }
  if (match(E, m_Shl(m_Value(Op), m_APInt(AI)))) {
    C = APInt(AI->getBitWidth(), 1);
    C <<= *AI;
    return true;
  }
  return false;
}

// Matches a remainder expression Op % C, where C is a constant. Returns the
// constant value in C and the other operand in Op. Returns the signedness of
// the remainder operation in IsSigned. Returns true if such a match is
// found.
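// For example, "urem i32 %x, 8" yields Op = %x, C = 8 with IsSigned = false,
// and the masking form "and i32 %x, 7" is recognized as the same unsigned
// remainder because 7 + 1 is a power of two.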
static bool MatchRem(Value *E, Value *&Op, APInt &C, bool &IsSigned) {
  const APInt *AI;
  IsSigned = false;
  if (match(E, m_SRem(m_Value(Op), m_APInt(AI)))) {
    IsSigned = true;
    C = *AI;
    return true;
  }
  if (match(E, m_URem(m_Value(Op), m_APInt(AI)))) {
    C = *AI;
    return true;
  }
  if (match(E, m_And(m_Value(Op), m_APInt(AI))) && (*AI + 1).isPowerOf2()) {
    C = *AI + 1;
    return true;
  }
  return false;
}

// Matches a division expression Op / C with the given signedness as
// indicated by IsSigned, where C is a constant. Returns the constant value
// in C and the other operand in Op. Returns true if such a match is found.
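// For example, with IsSigned = false, "udiv i32 %x, 8" yields Op = %x and
// C = 8, and so does the shift form "lshr i32 %x, 3"; with IsSigned = true,
// only an explicit "sdiv" is accepted.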
static bool MatchDiv(Value *E, Value *&Op, APInt &C, bool IsSigned) {
  const APInt *AI;
  if (IsSigned && match(E, m_SDiv(m_Value(Op), m_APInt(AI)))) {
    C = *AI;
    return true;
  }
  if (!IsSigned) {
    if (match(E, m_UDiv(m_Value(Op), m_APInt(AI)))) {
      C = *AI;
      return true;
    }
    if (match(E, m_LShr(m_Value(Op), m_APInt(AI)))) {
      C = APInt(AI->getBitWidth(), 1);
      C <<= *AI;
      return true;
    }
  }
  return false;
}

// Returns whether C0 * C1 with the given signedness overflows.
static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) {
  bool overflow;
  if (IsSigned)
    (void)C0.smul_ov(C1, overflow);
  else
    (void)C0.umul_ov(C1, overflow);
  return overflow;
}

// Simplifies X % C0 + ((X / C0) % C1) * C0 to X % (C0 * C1), where (C0 * C1)
// does not overflow.
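// For example (unsigned, X = 100): 100 % 4 + ((100 / 4) % 8) * 4
// = 0 + (25 % 8) * 4 = 4, which equals 100 % 32, since 4 * 8 = 32 does not
// overflow.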
Value *InstCombiner::SimplifyAddWithRemainder(BinaryOperator &I) {
  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
  Value *X, *MulOpV;
  APInt C0, MulOpC;
  bool IsSigned;
  // Match I = X % C0 + MulOpV * C0
  if (((MatchRem(LHS, X, C0, IsSigned) && MatchMul(RHS, MulOpV, MulOpC)) ||
       (MatchRem(RHS, X, C0, IsSigned) && MatchMul(LHS, MulOpV, MulOpC))) &&
      C0 == MulOpC) {
    Value *RemOpV;
    APInt C1;
    bool Rem2IsSigned;
    // Match MulOpV = RemOpV % C1
    if (MatchRem(MulOpV, RemOpV, C1, Rem2IsSigned) &&
        IsSigned == Rem2IsSigned) {
      Value *DivOpV;
      APInt DivOpC;
      // Match RemOpV = X / C0
      if (MatchDiv(RemOpV, DivOpV, DivOpC, IsSigned) && X == DivOpV &&
          C0 == DivOpC && !MulWillOverflow(C0, C1, IsSigned)) {
        Value *NewDivisor =
            ConstantInt::get(X->getType()->getContext(), C0 * C1);
        return IsSigned ? Builder.CreateSRem(X, NewDivisor, "srem")
                        : Builder.CreateURem(X, NewDivisor, "urem");
      }
    }
  }

  return nullptr;
}

/// Fold
///   (1 << NBits) - 1
/// Into:
///   ~(-(1 << NBits))
/// Because a 'not' is better for bit-tracking analysis and other transforms
/// than an 'add'. The new shl is always nsw, and is nuw if the old `add` was.
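/// For example, with NBits = 4 on an i8 value: (1 << 4) - 1 = 15, and
/// ~(-(1 << 4)) = ~0xF0 = 0x0F = 15 as well.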
static Instruction *canonicalizeLowbitMask(BinaryOperator &I,
                                           InstCombiner::BuilderTy &Builder) {
  Value *NBits;
  if (!match(&I, m_Add(m_OneUse(m_Shl(m_One(), m_Value(NBits))), m_AllOnes())))
    return nullptr;

  Constant *MinusOne = Constant::getAllOnesValue(NBits->getType());
  Value *NotMask = Builder.CreateShl(MinusOne, NBits, "notmask");
  // Be wary of constant folding.
  if (auto *BOp = dyn_cast<BinaryOperator>(NotMask)) {
    // Always NSW. But NUW propagates from `add`.
    BOp->setHasNoSignedWrap();
    BOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
  }

  return BinaryOperator::CreateNot(NotMask, I.getName());
}
1081*09467b48Spatrick 
1082*09467b48Spatrick static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
1083*09467b48Spatrick   assert(I.getOpcode() == Instruction::Add && "Expecting add instruction");
1084*09467b48Spatrick   Type *Ty = I.getType();
1085*09467b48Spatrick   auto getUAddSat = [&]() {
1086*09467b48Spatrick     return Intrinsic::getDeclaration(I.getModule(), Intrinsic::uadd_sat, Ty);
1087*09467b48Spatrick   };
1088*09467b48Spatrick 
1089*09467b48Spatrick   // add (umin X, ~Y), Y --> uaddsat X, Y
1090*09467b48Spatrick   Value *X, *Y;
1091*09467b48Spatrick   if (match(&I, m_c_Add(m_c_UMin(m_Value(X), m_Not(m_Value(Y))),
1092*09467b48Spatrick                         m_Deferred(Y))))
1093*09467b48Spatrick     return CallInst::Create(getUAddSat(), { X, Y });
1094*09467b48Spatrick 
1095*09467b48Spatrick   // add (umin X, ~C), C --> uaddsat X, C
1096*09467b48Spatrick   const APInt *C, *NotC;
1097*09467b48Spatrick   if (match(&I, m_Add(m_UMin(m_Value(X), m_APInt(NotC)), m_APInt(C))) &&
1098*09467b48Spatrick       *C == ~*NotC)
1099*09467b48Spatrick     return CallInst::Create(getUAddSat(), { X, ConstantInt::get(Ty, *C) });
1100*09467b48Spatrick 
1101*09467b48Spatrick   return nullptr;
1102*09467b48Spatrick }
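
// Illustrative IR sketch of the first fold above (editor's example, not from
// the original source; %x/%y are hypothetical, umin in its icmp+select form):
//   %noty = xor i8 %y, -1
//   %cmp  = icmp ult i8 %x, %noty
//   %min  = select i1 %cmp, i8 %x, i8 %noty    ; umin(%x, ~%y)
//   %add  = add i8 %min, %y
// becomes:
//   %add = call i8 @llvm.uadd.sat.i8(i8 %x, i8 %y)
// since %min + %y is %x + %y when that cannot wrap, and 0xFF otherwise.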
1103*09467b48Spatrick 
1104*09467b48Spatrick Instruction *
1105*09467b48Spatrick InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
1106*09467b48Spatrick     BinaryOperator &I) {
1107*09467b48Spatrick   assert((I.getOpcode() == Instruction::Add ||
1108*09467b48Spatrick           I.getOpcode() == Instruction::Or ||
1109*09467b48Spatrick           I.getOpcode() == Instruction::Sub) &&
1110*09467b48Spatrick          "Expecting add/or/sub instruction");
1111*09467b48Spatrick 
1112*09467b48Spatrick   // We have a subtraction/addition between a (potentially truncated) *logical*
1113*09467b48Spatrick   // right-shift of X and a "select".
1114*09467b48Spatrick   Value *X, *Select;
1115*09467b48Spatrick   Instruction *LowBitsToSkip, *Extract;
1116*09467b48Spatrick   if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd(
1117*09467b48Spatrick                                m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)),
1118*09467b48Spatrick                                m_Instruction(Extract))),
1119*09467b48Spatrick                            m_Value(Select))))
1120*09467b48Spatrick     return nullptr;
1121*09467b48Spatrick 
1122*09467b48Spatrick   // `add`/`or` is commutative; but for `sub`, "select" *must* be on RHS.
1123*09467b48Spatrick   if (I.getOpcode() == Instruction::Sub && I.getOperand(1) != Select)
1124*09467b48Spatrick     return nullptr;
1125*09467b48Spatrick 
1126*09467b48Spatrick   Type *XTy = X->getType();
1127*09467b48Spatrick   bool HadTrunc = I.getType() != XTy;
1128*09467b48Spatrick 
1129*09467b48Spatrick   // If there was a truncation of the extracted value, then we'll need to produce
1130*09467b48Spatrick   // one extra instruction, so we need to ensure one instruction will go away.
1131*09467b48Spatrick   if (HadTrunc && !match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
1132*09467b48Spatrick     return nullptr;
1133*09467b48Spatrick 
1134*09467b48Spatrick   // Extraction should extract high NBits bits, with shift amount calculated as:
1135*09467b48Spatrick   //   low bits to skip = shift bitwidth - high bits to extract
1136*09467b48Spatrick   // The shift amount itself may be extended, and we need to look past
1137*09467b48Spatrick   // zero-extension when matching NBits; that will matter for matching later.
1138*09467b48Spatrick   Constant *C;
1139*09467b48Spatrick   Value *NBits;
1140*09467b48Spatrick   if (!match(
1141*09467b48Spatrick           LowBitsToSkip,
1142*09467b48Spatrick           m_ZExtOrSelf(m_Sub(m_Constant(C), m_ZExtOrSelf(m_Value(NBits))))) ||
1143*09467b48Spatrick       !match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
1144*09467b48Spatrick                                    APInt(C->getType()->getScalarSizeInBits(),
1145*09467b48Spatrick                                          X->getType()->getScalarSizeInBits()))))
1146*09467b48Spatrick     return nullptr;
1147*09467b48Spatrick 
1148*09467b48Spatrick   // The sign-extending magic value may be zero-extended if we `sub`tract it,
1149*09467b48Spatrick   // but must be sign-extended otherwise.
1150*09467b48Spatrick   auto SkipExtInMagic = [&I](Value *&V) {
1151*09467b48Spatrick     if (I.getOpcode() == Instruction::Sub)
1152*09467b48Spatrick       match(V, m_ZExtOrSelf(m_Value(V)));
1153*09467b48Spatrick     else
1154*09467b48Spatrick       match(V, m_SExtOrSelf(m_Value(V)));
1155*09467b48Spatrick   };
1156*09467b48Spatrick 
1157*09467b48Spatrick   // Now, finally validate the sign-extending magic.
1158*09467b48Spatrick   // The `select` itself may be appropriately extended; look past that.
1159*09467b48Spatrick   SkipExtInMagic(Select);
1160*09467b48Spatrick 
1161*09467b48Spatrick   ICmpInst::Predicate Pred;
1162*09467b48Spatrick   const APInt *Thr;
1163*09467b48Spatrick   Value *SignExtendingValue, *Zero;
1164*09467b48Spatrick   bool ShouldSignext;
1165*09467b48Spatrick   // It must be a select between two values we will later establish to be a
1166*09467b48Spatrick   // sign-extending value and a zero constant. The condition guarding the
1167*09467b48Spatrick   // sign-extension must be based on a sign bit of the same X we had in `lshr`.
1168*09467b48Spatrick   if (!match(Select, m_Select(m_ICmp(Pred, m_Specific(X), m_APInt(Thr)),
1169*09467b48Spatrick                               m_Value(SignExtendingValue), m_Value(Zero))) ||
1170*09467b48Spatrick       !isSignBitCheck(Pred, *Thr, ShouldSignext))
1171*09467b48Spatrick     return nullptr;
1172*09467b48Spatrick 
1173*09467b48Spatrick   // icmp-select pair is commutative.
1174*09467b48Spatrick   if (!ShouldSignext)
1175*09467b48Spatrick     std::swap(SignExtendingValue, Zero);
1176*09467b48Spatrick 
1177*09467b48Spatrick   // If we should not perform sign-extension then we must add/or/subtract zero.
1178*09467b48Spatrick   if (!match(Zero, m_Zero()))
1179*09467b48Spatrick     return nullptr;
1180*09467b48Spatrick   // Otherwise, it should be some constant, left-shifted by the same NBits we
1181*09467b48Spatrick   // had in `lshr`. Said left-shift can also be appropriately extended.
1182*09467b48Spatrick   // Again, we must look past zero-ext when looking for NBits.
1183*09467b48Spatrick   SkipExtInMagic(SignExtendingValue);
1184*09467b48Spatrick   Constant *SignExtendingValueBaseConstant;
1185*09467b48Spatrick   if (!match(SignExtendingValue,
1186*09467b48Spatrick              m_Shl(m_Constant(SignExtendingValueBaseConstant),
1187*09467b48Spatrick                    m_ZExtOrSelf(m_Specific(NBits)))))
1188*09467b48Spatrick     return nullptr;
1189*09467b48Spatrick   // If we `sub`, then the constant should be one, else it should be all-ones.
1190*09467b48Spatrick   if (I.getOpcode() == Instruction::Sub
1191*09467b48Spatrick           ? !match(SignExtendingValueBaseConstant, m_One())
1192*09467b48Spatrick           : !match(SignExtendingValueBaseConstant, m_AllOnes()))
1193*09467b48Spatrick     return nullptr;
1194*09467b48Spatrick 
1195*09467b48Spatrick   auto *NewAShr = BinaryOperator::CreateAShr(X, LowBitsToSkip,
1196*09467b48Spatrick                                              Extract->getName() + ".sext");
1197*09467b48Spatrick   NewAShr->copyIRFlags(Extract); // Preserve `exact`-ness.
1198*09467b48Spatrick   if (!HadTrunc)
1199*09467b48Spatrick     return NewAShr;
1200*09467b48Spatrick 
1201*09467b48Spatrick   Builder.Insert(NewAShr);
1202*09467b48Spatrick   return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType());
1203*09467b48Spatrick }
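
// Illustrative i32 sketch of this canonicalization (editor's example, not
// from the original source; %x/%n are hypothetical), extracting and
// sign-extending the high %n bits of %x:
//   %skip  = sub i32 32, %n
//   %hi    = lshr i32 %x, %skip
//   %isneg = icmp slt i32 %x, 0
//   %ones  = shl i32 -1, %n
//   %magic = select i1 %isneg, i32 %ones, i32 0
//   %r     = add i32 %hi, %magic
// becomes a plain arithmetic shift:
//   %r = ashr i32 %x, %skip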
1204*09467b48Spatrick 
1205*09467b48Spatrick Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
1206*09467b48Spatrick   if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1),
1207*09467b48Spatrick                                  I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
1208*09467b48Spatrick                                  SQ.getWithInstruction(&I)))
1209*09467b48Spatrick     return replaceInstUsesWith(I, V);
1210*09467b48Spatrick 
1211*09467b48Spatrick   if (SimplifyAssociativeOrCommutative(I))
1212*09467b48Spatrick     return &I;
1213*09467b48Spatrick 
1214*09467b48Spatrick   if (Instruction *X = foldVectorBinop(I))
1215*09467b48Spatrick     return X;
1216*09467b48Spatrick 
1217*09467b48Spatrick   // (A*B)+(A*C) -> A*(B+C) etc
1218*09467b48Spatrick   if (Value *V = SimplifyUsingDistributiveLaws(I))
1219*09467b48Spatrick     return replaceInstUsesWith(I, V);
1220*09467b48Spatrick 
1221*09467b48Spatrick   if (Instruction *X = foldAddWithConstant(I))
1222*09467b48Spatrick     return X;
1223*09467b48Spatrick 
1224*09467b48Spatrick   if (Instruction *X = foldNoWrapAdd(I, Builder))
1225*09467b48Spatrick     return X;
1226*09467b48Spatrick 
1227*09467b48Spatrick   // FIXME: This should be moved into the above helper function to allow these
1228*09467b48Spatrick   // transforms for general constant or constant splat vectors.
1229*09467b48Spatrick   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
1230*09467b48Spatrick   Type *Ty = I.getType();
1231*09467b48Spatrick   if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
1232*09467b48Spatrick     Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr;
1233*09467b48Spatrick     if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) {
1234*09467b48Spatrick       unsigned TySizeBits = Ty->getScalarSizeInBits();
1235*09467b48Spatrick       const APInt &RHSVal = CI->getValue();
1236*09467b48Spatrick       unsigned ExtendAmt = 0;
1237*09467b48Spatrick       // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext.
1238*09467b48Spatrick       // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext.
1239*09467b48Spatrick       if (XorRHS->getValue() == -RHSVal) {
1240*09467b48Spatrick         if (RHSVal.isPowerOf2())
1241*09467b48Spatrick           ExtendAmt = TySizeBits - RHSVal.logBase2() - 1;
1242*09467b48Spatrick         else if (XorRHS->getValue().isPowerOf2())
1243*09467b48Spatrick           ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1;
1244*09467b48Spatrick       }
1245*09467b48Spatrick 
1246*09467b48Spatrick       if (ExtendAmt) {
1247*09467b48Spatrick         APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
1248*09467b48Spatrick         if (!MaskedValueIsZero(XorLHS, Mask, 0, &I))
1249*09467b48Spatrick           ExtendAmt = 0;
1250*09467b48Spatrick       }
1251*09467b48Spatrick 
1252*09467b48Spatrick       if (ExtendAmt) {
1253*09467b48Spatrick         Constant *ShAmt = ConstantInt::get(Ty, ExtendAmt);
1254*09467b48Spatrick         Value *NewShl = Builder.CreateShl(XorLHS, ShAmt, "sext");
1255*09467b48Spatrick         return BinaryOperator::CreateAShr(NewShl, ShAmt);
1256*09467b48Spatrick       }
1257*09467b48Spatrick 
1258*09467b48Spatrick       // If this is a xor that was canonicalized from a sub, turn it back into
1259*09467b48Spatrick       // a sub and fuse this add with it.
1260*09467b48Spatrick       if (LHS->hasOneUse() && (XorRHS->getValue()+1).isPowerOf2()) {
1261*09467b48Spatrick         KnownBits LHSKnown = computeKnownBits(XorLHS, 0, &I);
1262*09467b48Spatrick         if ((XorRHS->getValue() | LHSKnown.Zero).isAllOnesValue())
1263*09467b48Spatrick           return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI),
1264*09467b48Spatrick                                            XorLHS);
1265*09467b48Spatrick       }
1266*09467b48Spatrick       // (X + signmask) + C could have gotten canonicalized to (X ^ signmask) + C;
1267*09467b48Spatrick       // transform it back into (X + (signmask ^ C)).
1268*09467b48Spatrick       if (XorRHS->getValue().isSignMask())
1269*09467b48Spatrick         return BinaryOperator::CreateAdd(XorLHS,
1270*09467b48Spatrick                                          ConstantExpr::getXor(XorRHS, CI));
1271*09467b48Spatrick     }
1272*09467b48Spatrick   }
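
  // Illustrative i32 sketch of the sign-extension trick above (editor's
  // example, not from the original source; %x is hypothetical):
  //   %t = and i32 %x, 255
  //   %u = xor i32 %t, 128
  //   %r = add i32 %u, -128       ; sign-extends the low byte of %x
  // becomes:
  //   %shl = shl i32 %t, 24
  //   %r   = ashr i32 %shl, 24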
1273*09467b48Spatrick 
1274*09467b48Spatrick   if (Ty->isIntOrIntVectorTy(1))
1275*09467b48Spatrick     return BinaryOperator::CreateXor(LHS, RHS);
1276*09467b48Spatrick 
1277*09467b48Spatrick   // X + X --> X << 1
1278*09467b48Spatrick   if (LHS == RHS) {
1279*09467b48Spatrick     auto *Shl = BinaryOperator::CreateShl(LHS, ConstantInt::get(Ty, 1));
1280*09467b48Spatrick     Shl->setHasNoSignedWrap(I.hasNoSignedWrap());
1281*09467b48Spatrick     Shl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
1282*09467b48Spatrick     return Shl;
1283*09467b48Spatrick   }
1284*09467b48Spatrick 
1285*09467b48Spatrick   Value *A, *B;
1286*09467b48Spatrick   if (match(LHS, m_Neg(m_Value(A)))) {
1287*09467b48Spatrick     // -A + -B --> -(A + B)
1288*09467b48Spatrick     if (match(RHS, m_Neg(m_Value(B))))
1289*09467b48Spatrick       return BinaryOperator::CreateNeg(Builder.CreateAdd(A, B));
1290*09467b48Spatrick 
1291*09467b48Spatrick     // -A + B --> B - A
1292*09467b48Spatrick     return BinaryOperator::CreateSub(RHS, A);
1293*09467b48Spatrick   }
1294*09467b48Spatrick 
1295*09467b48Spatrick   // A + -B  -->  A - B
1296*09467b48Spatrick   if (match(RHS, m_Neg(m_Value(B))))
1297*09467b48Spatrick     return BinaryOperator::CreateSub(LHS, B);
1298*09467b48Spatrick 
1299*09467b48Spatrick   if (Value *V = checkForNegativeOperand(I, Builder))
1300*09467b48Spatrick     return replaceInstUsesWith(I, V);
1301*09467b48Spatrick 
1302*09467b48Spatrick   // (A + 1) + ~B --> A - B
1303*09467b48Spatrick   // ~B + (A + 1) --> A - B
1304*09467b48Spatrick   // (~B + A) + 1 --> A - B
1305*09467b48Spatrick   // (A + ~B) + 1 --> A - B
1306*09467b48Spatrick   if (match(&I, m_c_BinOp(m_Add(m_Value(A), m_One()), m_Not(m_Value(B)))) ||
1307*09467b48Spatrick       match(&I, m_BinOp(m_c_Add(m_Not(m_Value(B)), m_Value(A)), m_One())))
1308*09467b48Spatrick     return BinaryOperator::CreateSub(A, B);
1309*09467b48Spatrick 
1310*09467b48Spatrick   // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1)
1311*09467b48Spatrick   if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V);
1312*09467b48Spatrick 
1313*09467b48Spatrick   // A+B --> A|B iff A and B have no bits set in common.
1314*09467b48Spatrick   if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
1315*09467b48Spatrick     return BinaryOperator::CreateOr(LHS, RHS);
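
  // Illustrative sketch (editor's example, not from the original source;
  // %x/%y are hypothetical):
  //   %a = and i32 %x, -16        ; low 4 bits known zero
  //   %b = and i32 %y, 15         ; high 28 bits known zero
  //   %r = add i32 %a, %b
  // has no carries, so it becomes:
  //   %r = or i32 %a, %b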
1316*09467b48Spatrick 
1317*09467b48Spatrick   // FIXME: We already did a check for ConstantInt RHS above this.
1318*09467b48Spatrick   // FIXME: Is this pattern covered by another fold? No regression tests fail on
1319*09467b48Spatrick   // removal.
1320*09467b48Spatrick   if (ConstantInt *CRHS = dyn_cast<ConstantInt>(RHS)) {
1321*09467b48Spatrick     // (X & FF00) + xx00  -> (X+xx00) & FF00
1322*09467b48Spatrick     Value *X;
1323*09467b48Spatrick     ConstantInt *C2;
1324*09467b48Spatrick     if (LHS->hasOneUse() &&
1325*09467b48Spatrick         match(LHS, m_And(m_Value(X), m_ConstantInt(C2))) &&
1326*09467b48Spatrick         CRHS->getValue() == (CRHS->getValue() & C2->getValue())) {
1327*09467b48Spatrick       // See if all bits from the first bit set in the Add RHS up are included
1328*09467b48Spatrick       // in the mask.  First, get the rightmost bit.
1329*09467b48Spatrick       const APInt &AddRHSV = CRHS->getValue();
1330*09467b48Spatrick 
1331*09467b48Spatrick       // Form a mask of all bits from the lowest bit added through the top.
1332*09467b48Spatrick       APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1));
1333*09467b48Spatrick 
1334*09467b48Spatrick       // See if the and mask includes all of these bits.
1335*09467b48Spatrick       APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue());
1336*09467b48Spatrick 
1337*09467b48Spatrick       if (AddRHSHighBits == AddRHSHighBitsAnd) {
1338*09467b48Spatrick         // Okay, the xform is safe.  Insert the new add pronto.
1339*09467b48Spatrick         Value *NewAdd = Builder.CreateAdd(X, CRHS, LHS->getName());
1340*09467b48Spatrick         return BinaryOperator::CreateAnd(NewAdd, C2);
1341*09467b48Spatrick       }
1342*09467b48Spatrick     }
1343*09467b48Spatrick   }
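
  // Illustrative sketch of the fold above (editor's example, not from the
  // original source; %x is hypothetical):
  //   %m = and i32 %x, -256       ; mask covers every bit from bit 8 up
  //   %r = add i32 %m, 256
  // becomes (adding 256 never disturbs the masked-off low byte):
  //   %a = add i32 %x, 256
  //   %r = and i32 %a, -256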
1344*09467b48Spatrick 
1345*09467b48Spatrick   // add (select X 0 (sub n A)) A  -->  select X A n
1346*09467b48Spatrick   {
1347*09467b48Spatrick     SelectInst *SI = dyn_cast<SelectInst>(LHS);
1348*09467b48Spatrick     Value *A = RHS;
1349*09467b48Spatrick     if (!SI) {
1350*09467b48Spatrick       SI = dyn_cast<SelectInst>(RHS);
1351*09467b48Spatrick       A = LHS;
1352*09467b48Spatrick     }
1353*09467b48Spatrick     if (SI && SI->hasOneUse()) {
1354*09467b48Spatrick       Value *TV = SI->getTrueValue();
1355*09467b48Spatrick       Value *FV = SI->getFalseValue();
1356*09467b48Spatrick       Value *N;
1357*09467b48Spatrick 
1358*09467b48Spatrick       // Can we fold the add into the argument of the select?
1359*09467b48Spatrick       // We check both true and false select arguments for a matching subtract.
1360*09467b48Spatrick       if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A))))
1361*09467b48Spatrick         // Fold the add into the true select value.
1362*09467b48Spatrick         return SelectInst::Create(SI->getCondition(), N, A);
1363*09467b48Spatrick 
1364*09467b48Spatrick       if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A))))
1365*09467b48Spatrick         // Fold the add into the false select value.
1366*09467b48Spatrick         return SelectInst::Create(SI->getCondition(), A, N);
1367*09467b48Spatrick     }
1368*09467b48Spatrick   }
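
  // Illustrative sketch (editor's example, not from the original source;
  // %c/%n/%a are hypothetical):
  //   %d = sub i32 %n, %a
  //   %s = select i1 %c, i32 0, i32 %d
  //   %r = add i32 %s, %a
  // becomes (true arm: 0 + %a; false arm: (%n - %a) + %a):
  //   %r = select i1 %c, i32 %a, i32 %n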
1369*09467b48Spatrick 
1370*09467b48Spatrick   if (Instruction *Ext = narrowMathIfNoOverflow(I))
1371*09467b48Spatrick     return Ext;
1372*09467b48Spatrick 
1373*09467b48Spatrick   // (add (xor A, B) (and A, B)) --> (or A, B)
1374*09467b48Spatrick   // (add (and A, B) (xor A, B)) --> (or A, B)
1375*09467b48Spatrick   if (match(&I, m_c_BinOp(m_Xor(m_Value(A), m_Value(B)),
1376*09467b48Spatrick                           m_c_And(m_Deferred(A), m_Deferred(B)))))
1377*09467b48Spatrick     return BinaryOperator::CreateOr(A, B);
1378*09467b48Spatrick 
1379*09467b48Spatrick   // (add (or A, B) (and A, B)) --> (add A, B)
1380*09467b48Spatrick   // (add (and A, B) (or A, B)) --> (add A, B)
1381*09467b48Spatrick   if (match(&I, m_c_BinOp(m_Or(m_Value(A), m_Value(B)),
1382*09467b48Spatrick                           m_c_And(m_Deferred(A), m_Deferred(B))))) {
1383*09467b48Spatrick     I.setOperand(0, A);
1384*09467b48Spatrick     I.setOperand(1, B);
1385*09467b48Spatrick     return &I;
1386*09467b48Spatrick   }
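
  // Both folds above are bitwise identities (editor's note, not from the
  // original source): A ^ B and A & B have disjoint set bits, so their sum is
  // their 'or', which is A | B; likewise A + B always equals
  // (A | B) + (A & B), since 'or' collects each bit once and 'and' holds the
  // bits that must carry.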
1387*09467b48Spatrick 
1388*09467b48Spatrick   // TODO(jingyue): Consider willNotOverflowSignedAdd and
1389*09467b48Spatrick   // willNotOverflowUnsignedAdd to reduce the number of invocations of
1390*09467b48Spatrick   // computeKnownBits.
1391*09467b48Spatrick   bool Changed = false;
1392*09467b48Spatrick   if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) {
1393*09467b48Spatrick     Changed = true;
1394*09467b48Spatrick     I.setHasNoSignedWrap(true);
1395*09467b48Spatrick   }
1396*09467b48Spatrick   if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) {
1397*09467b48Spatrick     Changed = true;
1398*09467b48Spatrick     I.setHasNoUnsignedWrap(true);
1399*09467b48Spatrick   }
1400*09467b48Spatrick 
1401*09467b48Spatrick   if (Instruction *V = canonicalizeLowbitMask(I, Builder))
1402*09467b48Spatrick     return V;
1403*09467b48Spatrick 
1404*09467b48Spatrick   if (Instruction *V =
1405*09467b48Spatrick           canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
1406*09467b48Spatrick     return V;
1407*09467b48Spatrick 
1408*09467b48Spatrick   if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I))
1409*09467b48Spatrick     return SatAdd;
1410*09467b48Spatrick 
1411*09467b48Spatrick   return Changed ? &I : nullptr;
1412*09467b48Spatrick }
1413*09467b48Spatrick 
1414*09467b48Spatrick /// Eliminate an op from a linear interpolation (lerp) pattern.
1415*09467b48Spatrick static Instruction *factorizeLerp(BinaryOperator &I,
1416*09467b48Spatrick                                   InstCombiner::BuilderTy &Builder) {
1417*09467b48Spatrick   Value *X, *Y, *Z;
1418*09467b48Spatrick   if (!match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_Value(Y),
1419*09467b48Spatrick                                             m_OneUse(m_FSub(m_FPOne(),
1420*09467b48Spatrick                                                             m_Value(Z))))),
1421*09467b48Spatrick                           m_OneUse(m_c_FMul(m_Value(X), m_Deferred(Z))))))
1422*09467b48Spatrick     return nullptr;
1423*09467b48Spatrick 
1424*09467b48Spatrick   // (Y * (1.0 - Z)) + (X * Z) --> Y + Z * (X - Y) [8 commuted variants]
1425*09467b48Spatrick   Value *XY = Builder.CreateFSubFMF(X, Y, &I);
1426*09467b48Spatrick   Value *MulZ = Builder.CreateFMulFMF(Z, XY, &I);
1427*09467b48Spatrick   return BinaryOperator::CreateFAddFMF(Y, MulZ, &I);
1428*09467b48Spatrick }
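
// Illustrative float sketch of the lerp factorization (editor's example, not
// from the original source; values are hypothetical and fast-math flags come
// from the original fadd):
//   %omz = fsub fast float 1.0, %z
//   %t0  = fmul fast float %y, %omz
//   %t1  = fmul fast float %x, %z
//   %r   = fadd fast float %t0, %t1    ; %y*(1-%z) + %x*%z
// becomes:
//   %d = fsub fast float %x, %y
//   %m = fmul fast float %z, %d
//   %r = fadd fast float %y, %m        ; %y + %z*(%x-%y)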
1429*09467b48Spatrick 
1430*09467b48Spatrick /// Factor a common operand out of fadd/fsub of fmul/fdiv.
1431*09467b48Spatrick static Instruction *factorizeFAddFSub(BinaryOperator &I,
1432*09467b48Spatrick                                       InstCombiner::BuilderTy &Builder) {
1433*09467b48Spatrick   assert((I.getOpcode() == Instruction::FAdd ||
1434*09467b48Spatrick           I.getOpcode() == Instruction::FSub) && "Expecting fadd/fsub");
1435*09467b48Spatrick   assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
1436*09467b48Spatrick          "FP factorization requires FMF");
1437*09467b48Spatrick 
1438*09467b48Spatrick   if (Instruction *Lerp = factorizeLerp(I, Builder))
1439*09467b48Spatrick     return Lerp;
1440*09467b48Spatrick 
1441*09467b48Spatrick   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1442*09467b48Spatrick   Value *X, *Y, *Z;
1443*09467b48Spatrick   bool IsFMul;
1444*09467b48Spatrick   if ((match(Op0, m_OneUse(m_FMul(m_Value(X), m_Value(Z)))) &&
1445*09467b48Spatrick        match(Op1, m_OneUse(m_c_FMul(m_Value(Y), m_Specific(Z))))) ||
1446*09467b48Spatrick       (match(Op0, m_OneUse(m_FMul(m_Value(Z), m_Value(X)))) &&
1447*09467b48Spatrick        match(Op1, m_OneUse(m_c_FMul(m_Value(Y), m_Specific(Z))))))
1448*09467b48Spatrick     IsFMul = true;
1449*09467b48Spatrick   else if (match(Op0, m_OneUse(m_FDiv(m_Value(X), m_Value(Z)))) &&
1450*09467b48Spatrick            match(Op1, m_OneUse(m_FDiv(m_Value(Y), m_Specific(Z)))))
1451*09467b48Spatrick     IsFMul = false;
1452*09467b48Spatrick   else
1453*09467b48Spatrick     return nullptr;
1454*09467b48Spatrick 
1455*09467b48Spatrick   // (X * Z) + (Y * Z) --> (X + Y) * Z
1456*09467b48Spatrick   // (X * Z) - (Y * Z) --> (X - Y) * Z
1457*09467b48Spatrick   // (X / Z) + (Y / Z) --> (X + Y) / Z
1458*09467b48Spatrick   // (X / Z) - (Y / Z) --> (X - Y) / Z
1459*09467b48Spatrick   bool IsFAdd = I.getOpcode() == Instruction::FAdd;
1460*09467b48Spatrick   Value *XY = IsFAdd ? Builder.CreateFAddFMF(X, Y, &I)
1461*09467b48Spatrick                      : Builder.CreateFSubFMF(X, Y, &I);
1462*09467b48Spatrick 
1463*09467b48Spatrick   // Bail out if we just created a denormal constant.
1464*09467b48Spatrick   // TODO: This is copied from a previous implementation. Is it necessary?
1465*09467b48Spatrick   const APFloat *C;
1466*09467b48Spatrick   if (match(XY, m_APFloat(C)) && !C->isNormal())
1467*09467b48Spatrick     return nullptr;
1468*09467b48Spatrick 
1469*09467b48Spatrick   return IsFMul ? BinaryOperator::CreateFMulFMF(XY, Z, &I)
1470*09467b48Spatrick                 : BinaryOperator::CreateFDivFMF(XY, Z, &I);
1471*09467b48Spatrick }
1472*09467b48Spatrick 
1473*09467b48Spatrick Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
1474*09467b48Spatrick   if (Value *V = SimplifyFAddInst(I.getOperand(0), I.getOperand(1),
1475*09467b48Spatrick                                   I.getFastMathFlags(),
1476*09467b48Spatrick                                   SQ.getWithInstruction(&I)))
1477*09467b48Spatrick     return replaceInstUsesWith(I, V);
1478*09467b48Spatrick 
1479*09467b48Spatrick   if (SimplifyAssociativeOrCommutative(I))
1480*09467b48Spatrick     return &I;
1481*09467b48Spatrick 
1482*09467b48Spatrick   if (Instruction *X = foldVectorBinop(I))
1483*09467b48Spatrick     return X;
1484*09467b48Spatrick 
1485*09467b48Spatrick   if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I))
1486*09467b48Spatrick     return FoldedFAdd;
1487*09467b48Spatrick 
1488*09467b48Spatrick   // (-X) + Y --> Y - X
1489*09467b48Spatrick   Value *X, *Y;
1490*09467b48Spatrick   if (match(&I, m_c_FAdd(m_FNeg(m_Value(X)), m_Value(Y))))
1491*09467b48Spatrick     return BinaryOperator::CreateFSubFMF(Y, X, &I);
1492*09467b48Spatrick 
1493*09467b48Spatrick   // Similar to above, but look through fmul/fdiv for the negated term.
1494*09467b48Spatrick   // (-X * Y) + Z --> Z - (X * Y) [4 commuted variants]
1495*09467b48Spatrick   Value *Z;
1496*09467b48Spatrick   if (match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))),
1497*09467b48Spatrick                          m_Value(Z)))) {
1498*09467b48Spatrick     Value *XY = Builder.CreateFMulFMF(X, Y, &I);
1499*09467b48Spatrick     return BinaryOperator::CreateFSubFMF(Z, XY, &I);
1500*09467b48Spatrick   }
1501*09467b48Spatrick   // (-X / Y) + Z --> Z - (X / Y) [2 commuted variants]
1502*09467b48Spatrick   // (X / -Y) + Z --> Z - (X / Y) [2 commuted variants]
1503*09467b48Spatrick   if (match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y))),
1504*09467b48Spatrick                          m_Value(Z))) ||
1505*09467b48Spatrick       match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))),
1506*09467b48Spatrick                          m_Value(Z)))) {
1507*09467b48Spatrick     Value *XY = Builder.CreateFDivFMF(X, Y, &I);
1508*09467b48Spatrick     return BinaryOperator::CreateFSubFMF(Z, XY, &I);
1509*09467b48Spatrick   }
1510*09467b48Spatrick 
1511*09467b48Spatrick   // Check for (fadd double (sitofp x), y); see if we can merge this into an
1512*09467b48Spatrick   // integer add followed by a promotion.
1513*09467b48Spatrick   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
1514*09467b48Spatrick   if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
1515*09467b48Spatrick     Value *LHSIntVal = LHSConv->getOperand(0);
1516*09467b48Spatrick     Type *FPType = LHSConv->getType();
1517*09467b48Spatrick 
1518*09467b48Spatrick     // TODO: This check is overly conservative. In many cases known bits
1519*09467b48Spatrick     // analysis can tell us that the result of the addition has fewer significant
1520*09467b48Spatrick     // bits than the integer type can hold.
1521*09467b48Spatrick     auto IsValidPromotion = [](Type *FTy, Type *ITy) {
1522*09467b48Spatrick       Type *FScalarTy = FTy->getScalarType();
1523*09467b48Spatrick       Type *IScalarTy = ITy->getScalarType();
1524*09467b48Spatrick 
1525*09467b48Spatrick       // Do we have enough bits in the significand to represent the result of
1526*09467b48Spatrick       // the integer addition?
1527*09467b48Spatrick       unsigned MaxRepresentableBits =
1528*09467b48Spatrick           APFloat::semanticsPrecision(FScalarTy->getFltSemantics());
1529*09467b48Spatrick       return IScalarTy->getIntegerBitWidth() <= MaxRepresentableBits;
1530*09467b48Spatrick     };
1531*09467b48Spatrick 
1532*09467b48Spatrick     // (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst))
1533*09467b48Spatrick     // ... if the constant fits in the integer value.  This is useful for things
1534*09467b48Spatrick     // like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer
1535*09467b48Spatrick     // requires a constant pool load, and generally allows the add to be better
1536*09467b48Spatrick     // instcombined.
1537*09467b48Spatrick     if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS))
1538*09467b48Spatrick       if (IsValidPromotion(FPType, LHSIntVal->getType())) {
1539*09467b48Spatrick         Constant *CI =
1540*09467b48Spatrick           ConstantExpr::getFPToSI(CFP, LHSIntVal->getType());
1541*09467b48Spatrick         if (LHSConv->hasOneUse() &&
1542*09467b48Spatrick             ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
1543*09467b48Spatrick             willNotOverflowSignedAdd(LHSIntVal, CI, I)) {
1544*09467b48Spatrick           // Insert the new integer add.
1545*09467b48Spatrick           Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv");
1546*09467b48Spatrick           return new SIToFPInst(NewAdd, I.getType());
1547*09467b48Spatrick         }
1548*09467b48Spatrick       }
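
    // Illustrative sketch of this promotion (editor's example, not from the
    // original source; the hypothetical 'and' makes the no-overflow proof
    // easy for known-bits analysis):
    //   %x = and i16 %w, 255
    //   %f = sitofp i16 %x to double
    //   %r = fadd double %f, 4.0
    // becomes:
    //   %a = add nsw i16 %x, 4
    //   %r = sitofp i16 %a to double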
1549*09467b48Spatrick 
1550*09467b48Spatrick     // (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))
1551*09467b48Spatrick     if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) {
1552*09467b48Spatrick       Value *RHSIntVal = RHSConv->getOperand(0);
1553*09467b48Spatrick       // It's enough to check LHS types only because we require int types to
1554*09467b48Spatrick       // be the same for this transform.
1555*09467b48Spatrick       if (IsValidPromotion(FPType, LHSIntVal->getType())) {
1556*09467b48Spatrick         // Only do this if x/y have the same type, if at least one of them has a
1557*09467b48Spatrick         // single use (so we don't increase the number of int->fp conversions),
1558*09467b48Spatrick         // and if the integer add will not overflow.
1559*09467b48Spatrick         if (LHSIntVal->getType() == RHSIntVal->getType() &&
1560*09467b48Spatrick             (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
1561*09467b48Spatrick             willNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) {
1562*09467b48Spatrick           // Insert the new integer add.
1563*09467b48Spatrick           Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, RHSIntVal, "addconv");
1564*09467b48Spatrick           return new SIToFPInst(NewAdd, I.getType());
1565*09467b48Spatrick         }
1566*09467b48Spatrick       }
1567*09467b48Spatrick     }
1568*09467b48Spatrick   }
1569*09467b48Spatrick 
1570*09467b48Spatrick   // Handle special cases for FAdd with selects feeding the operation
1571*09467b48Spatrick   if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS))
1572*09467b48Spatrick     return replaceInstUsesWith(I, V);
1573*09467b48Spatrick 
1574*09467b48Spatrick   if (I.hasAllowReassoc() && I.hasNoSignedZeros()) {
1575*09467b48Spatrick     if (Instruction *F = factorizeFAddFSub(I, Builder))
1576*09467b48Spatrick       return F;
1577*09467b48Spatrick     if (Value *V = FAddCombine(Builder).simplify(&I))
1578*09467b48Spatrick       return replaceInstUsesWith(I, V);
1579*09467b48Spatrick   }
1580*09467b48Spatrick 
1581*09467b48Spatrick   return nullptr;
1582*09467b48Spatrick }
1583*09467b48Spatrick 
1584*09467b48Spatrick /// Optimize pointer differences in the same array into a size.  Consider:
1585*09467b48Spatrick ///  &A[10] - &A[0]: we should compile this to "10".  LHS/RHS are the pointer
1586*09467b48Spatrick /// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
1587*09467b48Spatrick Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
1588*09467b48Spatrick                                                Type *Ty, bool IsNUW) {
1589*09467b48Spatrick   // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
1590*09467b48Spatrick   // this.
1591*09467b48Spatrick   bool Swapped = false;
1592*09467b48Spatrick   GEPOperator *GEP1 = nullptr, *GEP2 = nullptr;
1593*09467b48Spatrick 
1594*09467b48Spatrick   // For now we require one side to be the base pointer "A" or a constant
1595*09467b48Spatrick   // GEP derived from it.
1596*09467b48Spatrick   if (GEPOperator *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
1597*09467b48Spatrick     // (gep X, ...) - X
1598*09467b48Spatrick     if (LHSGEP->getOperand(0) == RHS) {
1599*09467b48Spatrick       GEP1 = LHSGEP;
1600*09467b48Spatrick       Swapped = false;
1601*09467b48Spatrick     } else if (GEPOperator *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
1602*09467b48Spatrick       // (gep X, ...) - (gep X, ...)
1603*09467b48Spatrick       if (LHSGEP->getOperand(0)->stripPointerCasts() ==
1604*09467b48Spatrick             RHSGEP->getOperand(0)->stripPointerCasts()) {
1605*09467b48Spatrick         GEP2 = RHSGEP;
1606*09467b48Spatrick         GEP1 = LHSGEP;
1607*09467b48Spatrick         Swapped = false;
1608*09467b48Spatrick       }
1609*09467b48Spatrick     }
1610*09467b48Spatrick   }
1611*09467b48Spatrick 
1612*09467b48Spatrick   if (GEPOperator *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
1613*09467b48Spatrick     // X - (gep X, ...)
1614*09467b48Spatrick     if (RHSGEP->getOperand(0) == LHS) {
1615*09467b48Spatrick       GEP1 = RHSGEP;
1616*09467b48Spatrick       Swapped = true;
1617*09467b48Spatrick     } else if (GEPOperator *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
1618*09467b48Spatrick       // (gep X, ...) - (gep X, ...)
1619*09467b48Spatrick       if (RHSGEP->getOperand(0)->stripPointerCasts() ==
1620*09467b48Spatrick             LHSGEP->getOperand(0)->stripPointerCasts()) {
1621*09467b48Spatrick         GEP2 = LHSGEP;
1622*09467b48Spatrick         GEP1 = RHSGEP;
1623*09467b48Spatrick         Swapped = true;
1624*09467b48Spatrick       }
1625*09467b48Spatrick     }
1626*09467b48Spatrick   }
1627*09467b48Spatrick 
1628*09467b48Spatrick   if (!GEP1)
1629*09467b48Spatrick     // No GEP found.
1630*09467b48Spatrick     return nullptr;
1631*09467b48Spatrick 
1632*09467b48Spatrick   if (GEP2) {
1633*09467b48Spatrick     // (gep X, ...) - (gep X, ...)
1634*09467b48Spatrick     //
1635*09467b48Spatrick     // Avoid duplicating the arithmetic if there is more than one non-constant
1636*09467b48Spatrick     // index between the two GEPs and either GEP with a non-constant index has
1637*09467b48Spatrick     // multiple users. With zero non-constant indices, the result is a
1638*09467b48Spatrick     // constant and there is no duplication. With one non-constant index, the
1639*09467b48Spatrick     // result is an add or sub with a constant, which is no larger than the
1640*09467b48Spatrick     // original code and duplicates no arithmetic, even if either GEP has
1641*09467b48Spatrick     // multiple users. With more than one non-constant index combined, there
1642*09467b48Spatrick     // is still no duplication as long as no GEP with a non-constant index
1643*09467b48Spatrick     // has multiple users.
1644*09467b48Spatrick     unsigned NumNonConstantIndices1 = GEP1->countNonConstantIndices();
1645*09467b48Spatrick     unsigned NumNonConstantIndices2 = GEP2->countNonConstantIndices();
1646*09467b48Spatrick     if (NumNonConstantIndices1 + NumNonConstantIndices2 > 1 &&
1647*09467b48Spatrick         ((NumNonConstantIndices1 > 0 && !GEP1->hasOneUse()) ||
1648*09467b48Spatrick          (NumNonConstantIndices2 > 0 && !GEP2->hasOneUse()))) {
1649*09467b48Spatrick       return nullptr;
1650*09467b48Spatrick     }
1651*09467b48Spatrick   }
1652*09467b48Spatrick 
1653*09467b48Spatrick   // Emit the offset of the GEP as an intptr_t.
1654*09467b48Spatrick   Value *Result = EmitGEPOffset(GEP1);
1655*09467b48Spatrick 
1656*09467b48Spatrick   // If this is a single inbounds GEP and the original sub was nuw,
1657*09467b48Spatrick   // then the final multiplication is also nuw. We match an extra add zero
1658*09467b48Spatrick   // here, because that's what EmitGEPOffset() generates.
1659*09467b48Spatrick   Instruction *I;
1660*09467b48Spatrick   if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() &&
1661*09467b48Spatrick       match(Result, m_Add(m_Instruction(I), m_Zero())) &&
1662*09467b48Spatrick       I->getOpcode() == Instruction::Mul)
1663*09467b48Spatrick     I->setHasNoUnsignedWrap();
1664*09467b48Spatrick 
1665*09467b48Spatrick   // If we had a constant expression GEP on the other side offsetting the
1666*09467b48Spatrick   // pointer, subtract it from the offset we have.
1667*09467b48Spatrick   if (GEP2) {
1668*09467b48Spatrick     Value *Offset = EmitGEPOffset(GEP2);
1669*09467b48Spatrick     Result = Builder.CreateSub(Result, Offset);
1670*09467b48Spatrick   }
1671*09467b48Spatrick 
1672*09467b48Spatrick   // If we have p - gep(p, ...)  then we have to negate the result.
1673*09467b48Spatrick   if (Swapped)
1674*09467b48Spatrick     Result = Builder.CreateNeg(Result, "diff.neg");
1675*09467b48Spatrick 
1676*09467b48Spatrick   return Builder.CreateIntCast(Result, Ty, true);
1677*09467b48Spatrick }
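
// Illustrative sketch (editor's example, not from the original source;
// %a/%i are hypothetical): for the difference of a gep and its base,
//   %gep = getelementptr inbounds i32, i32* %a, i64 %i
//   %pl  = ptrtoint i32* %gep to i64
//   %pr  = ptrtoint i32* %a to i64
//   %d   = sub i64 %pl, %pr
// EmitGEPOffset() yields the byte offset directly:
//   %d = mul i64 %i, 4          ; index times element size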
1678*09467b48Spatrick 
1679*09467b48Spatrick Instruction *InstCombiner::visitSub(BinaryOperator &I) {
1680*09467b48Spatrick   if (Value *V = SimplifySubInst(I.getOperand(0), I.getOperand(1),
1681*09467b48Spatrick                                  I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
1682*09467b48Spatrick                                  SQ.getWithInstruction(&I)))
1683*09467b48Spatrick     return replaceInstUsesWith(I, V);
1684*09467b48Spatrick 
1685*09467b48Spatrick   if (Instruction *X = foldVectorBinop(I))
1686*09467b48Spatrick     return X;
1687*09467b48Spatrick 
1688*09467b48Spatrick   // (A*B)-(A*C) -> A*(B-C) etc
1689*09467b48Spatrick   if (Value *V = SimplifyUsingDistributiveLaws(I))
1690*09467b48Spatrick     return replaceInstUsesWith(I, V);
1691*09467b48Spatrick 
1692*09467b48Spatrick   // If this is a 'B = x-(-A)', change to B = x+A.
1693*09467b48Spatrick   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1694*09467b48Spatrick   if (Value *V = dyn_castNegVal(Op1)) {
1695*09467b48Spatrick     BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V);
1696*09467b48Spatrick 
1697*09467b48Spatrick     if (const auto *BO = dyn_cast<BinaryOperator>(Op1)) {
1698*09467b48Spatrick       assert(BO->getOpcode() == Instruction::Sub &&
1699*09467b48Spatrick              "Expected a subtraction operator!");
1700*09467b48Spatrick       if (BO->hasNoSignedWrap() && I.hasNoSignedWrap())
1701*09467b48Spatrick         Res->setHasNoSignedWrap(true);
1702*09467b48Spatrick     } else {
1703*09467b48Spatrick       if (cast<Constant>(Op1)->isNotMinSignedValue() && I.hasNoSignedWrap())
1704*09467b48Spatrick         Res->setHasNoSignedWrap(true);
1705*09467b48Spatrick     }
1706*09467b48Spatrick 
1707*09467b48Spatrick     return Res;
1708*09467b48Spatrick   }
1709*09467b48Spatrick 
1710*09467b48Spatrick   if (I.getType()->isIntOrIntVectorTy(1))
1711*09467b48Spatrick     return BinaryOperator::CreateXor(Op0, Op1);
1712*09467b48Spatrick 
1713*09467b48Spatrick   // Replace (-1 - A) with (~A).
1714*09467b48Spatrick   if (match(Op0, m_AllOnes()))
1715*09467b48Spatrick     return BinaryOperator::CreateNot(Op1);
1716*09467b48Spatrick 
1717*09467b48Spatrick   // (~X) - (~Y) --> Y - X
1718*09467b48Spatrick   Value *X, *Y;
1719*09467b48Spatrick   if (match(Op0, m_Not(m_Value(X))) && match(Op1, m_Not(m_Value(Y))))
1720*09467b48Spatrick     return BinaryOperator::CreateSub(Y, X);
1721*09467b48Spatrick 
1722*09467b48Spatrick   // (X + -1) - Y --> ~Y + X
1723*09467b48Spatrick   if (match(Op0, m_OneUse(m_Add(m_Value(X), m_AllOnes()))))
1724*09467b48Spatrick     return BinaryOperator::CreateAdd(Builder.CreateNot(Op1), X);
1725*09467b48Spatrick 
1726*09467b48Spatrick   // Y - (X + 1) --> ~X + Y
1727*09467b48Spatrick   if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One()))))
1728*09467b48Spatrick     return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0);
1729*09467b48Spatrick 
1730*09467b48Spatrick   // Y - ~X --> (X + 1) + Y
1731*09467b48Spatrick   if (match(Op1, m_OneUse(m_Not(m_Value(X))))) {
1732*09467b48Spatrick     return BinaryOperator::CreateAdd(
1733*09467b48Spatrick         Builder.CreateAdd(Op0, ConstantInt::get(I.getType(), 1)), X);
1734*09467b48Spatrick   }
1735*09467b48Spatrick 
1736*09467b48Spatrick   if (Constant *C = dyn_cast<Constant>(Op0)) {
1737*09467b48Spatrick     bool IsNegate = match(C, m_ZeroInt());
1738*09467b48Spatrick     Value *X;
1739*09467b48Spatrick     if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
1740*09467b48Spatrick       // 0 - (zext bool) --> sext bool
1741*09467b48Spatrick       // C - (zext bool) --> bool ? C - 1 : C
1742*09467b48Spatrick       if (IsNegate)
1743*09467b48Spatrick         return CastInst::CreateSExtOrBitCast(X, I.getType());
1744*09467b48Spatrick       return SelectInst::Create(X, SubOne(C), C);
1745*09467b48Spatrick     }
1746*09467b48Spatrick     if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
1747*09467b48Spatrick       // 0 - (sext bool) --> zext bool
1748*09467b48Spatrick       // C - (sext bool) --> bool ? C + 1 : C
1749*09467b48Spatrick       if (IsNegate)
1750*09467b48Spatrick         return CastInst::CreateZExtOrBitCast(X, I.getType());
1751*09467b48Spatrick       return SelectInst::Create(X, AddOne(C), C);
1752*09467b48Spatrick     }
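
    // Illustrative sketch (editor's example, not from the original source;
    // %b and the constant are hypothetical):
    //   %z = zext i1 %b to i32
    //   %r = sub i32 7, %z
    // becomes:
    //   %r = select i1 %b, i32 6, i32 7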
1753*09467b48Spatrick 
1754*09467b48Spatrick     // C - ~X == X + (1+C)
1755*09467b48Spatrick     if (match(Op1, m_Not(m_Value(X))))
1756*09467b48Spatrick       return BinaryOperator::CreateAdd(X, AddOne(C));
1757*09467b48Spatrick 
1758*09467b48Spatrick     // Try to fold constant sub into select arguments.
1759*09467b48Spatrick     if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
1760*09467b48Spatrick       if (Instruction *R = FoldOpIntoSelect(I, SI))
1761*09467b48Spatrick         return R;
1762*09467b48Spatrick 
1763*09467b48Spatrick     // Try to fold constant sub into PHI values.
1764*09467b48Spatrick     if (PHINode *PN = dyn_cast<PHINode>(Op1))
1765*09467b48Spatrick       if (Instruction *R = foldOpIntoPhi(I, PN))
1766*09467b48Spatrick         return R;
1767*09467b48Spatrick 
1768*09467b48Spatrick     Constant *C2;
1769*09467b48Spatrick 
1770*09467b48Spatrick     // C-(C2-X) --> X+(C-C2)
1771*09467b48Spatrick     if (match(Op1, m_Sub(m_Constant(C2), m_Value(X))))
1772*09467b48Spatrick       return BinaryOperator::CreateAdd(X, ConstantExpr::getSub(C, C2));
1773*09467b48Spatrick 
1774*09467b48Spatrick     // C-(X+C2) --> (C-C2)-X
1775*09467b48Spatrick     if (match(Op1, m_Add(m_Value(X), m_Constant(C2))))
1776*09467b48Spatrick       return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
1777*09467b48Spatrick   }
1778*09467b48Spatrick 
1779*09467b48Spatrick   const APInt *Op0C;
1780*09467b48Spatrick   if (match(Op0, m_APInt(Op0C))) {
1781*09467b48Spatrick 
1782*09467b48Spatrick     if (Op0C->isNullValue()) {
1783*09467b48Spatrick       Value *Op1Wide;
1784*09467b48Spatrick       match(Op1, m_TruncOrSelf(m_Value(Op1Wide)));
1785*09467b48Spatrick       bool HadTrunc = Op1Wide != Op1;
1786*09467b48Spatrick       bool NoTruncOrTruncIsOneUse = !HadTrunc || Op1->hasOneUse();
1787*09467b48Spatrick       unsigned BitWidth = Op1Wide->getType()->getScalarSizeInBits();
1788*09467b48Spatrick 
1789*09467b48Spatrick       Value *X;
1790*09467b48Spatrick       const APInt *ShAmt;
1791*09467b48Spatrick       // -(X >>u 31) -> (X >>s 31)
1792*09467b48Spatrick       if (NoTruncOrTruncIsOneUse &&
1793*09467b48Spatrick           match(Op1Wide, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
1794*09467b48Spatrick           *ShAmt == BitWidth - 1) {
1795*09467b48Spatrick         Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
1796*09467b48Spatrick         Instruction *NewShift = BinaryOperator::CreateAShr(X, ShAmtOp);
1797*09467b48Spatrick         NewShift->copyIRFlags(Op1Wide);
1798*09467b48Spatrick         if (!HadTrunc)
1799*09467b48Spatrick           return NewShift;
1800*09467b48Spatrick         Builder.Insert(NewShift);
1801*09467b48Spatrick         return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
1802*09467b48Spatrick       }
1803*09467b48Spatrick       // -(X >>s 31) -> (X >>u 31)
1804*09467b48Spatrick       if (NoTruncOrTruncIsOneUse &&
1805*09467b48Spatrick           match(Op1Wide, m_AShr(m_Value(X), m_APInt(ShAmt))) &&
1806*09467b48Spatrick           *ShAmt == BitWidth - 1) {
1807*09467b48Spatrick         Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
1808*09467b48Spatrick         Instruction *NewShift = BinaryOperator::CreateLShr(X, ShAmtOp);
1809*09467b48Spatrick         NewShift->copyIRFlags(Op1Wide);
1810*09467b48Spatrick         if (!HadTrunc)
1811*09467b48Spatrick           return NewShift;
1812*09467b48Spatrick         Builder.Insert(NewShift);
1813*09467b48Spatrick         return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
1814*09467b48Spatrick       }
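
      // Illustrative sketch of the two shift folds above (editor's example,
      // not from the original source; %x is hypothetical):
      //   %s = lshr i32 %x, 31      ; 0 or 1
      //   %r = sub i32 0, %s        ; 0 or -1
      // becomes:
      //   %r = ashr i32 %x, 31
      // and symmetrically for negating an ashr by 31.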
1815*09467b48Spatrick 
1816*09467b48Spatrick       if (!HadTrunc && Op1->hasOneUse()) {
1817*09467b48Spatrick         Value *LHS, *RHS;
1818*09467b48Spatrick         SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor;
1819*09467b48Spatrick         if (SPF == SPF_ABS || SPF == SPF_NABS) {
1820*09467b48Spatrick           // This is a negate of an ABS/NABS pattern. Just swap the operands
1821*09467b48Spatrick           // of the select.
1822*09467b48Spatrick           cast<SelectInst>(Op1)->swapValues();
1823*09467b48Spatrick           // Don't swap prof metadata; we didn't change the branch behavior.
1824*09467b48Spatrick           return replaceInstUsesWith(I, Op1);
1825*09467b48Spatrick         }
1826*09467b48Spatrick       }
1827*09467b48Spatrick     }
1828*09467b48Spatrick 
1829*09467b48Spatrick     // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
1830*09467b48Spatrick     // zero.
1831*09467b48Spatrick     if (Op0C->isMask()) {
1832*09467b48Spatrick       KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
1833*09467b48Spatrick       if ((*Op0C | RHSKnown.Zero).isAllOnesValue())
1834*09467b48Spatrick         return BinaryOperator::CreateXor(Op1, Op0);
1835*09467b48Spatrick     }
1836*09467b48Spatrick   }
1837*09467b48Spatrick 
1838*09467b48Spatrick   {
1839*09467b48Spatrick     Value *Y;
1840*09467b48Spatrick     // X-(X+Y) == -Y    X-(Y+X) == -Y
1841*09467b48Spatrick     if (match(Op1, m_c_Add(m_Specific(Op0), m_Value(Y))))
1842*09467b48Spatrick       return BinaryOperator::CreateNeg(Y);
1843*09467b48Spatrick 
1844*09467b48Spatrick     // (X-Y)-X == -Y
1845*09467b48Spatrick     if (match(Op0, m_Sub(m_Specific(Op1), m_Value(Y))))
1846*09467b48Spatrick       return BinaryOperator::CreateNeg(Y);
1847*09467b48Spatrick   }
1848*09467b48Spatrick 
1849*09467b48Spatrick   // (sub (or A, B) (and A, B)) --> (xor A, B)
1850*09467b48Spatrick   {
1851*09467b48Spatrick     Value *A, *B;
1852*09467b48Spatrick     if (match(Op1, m_And(m_Value(A), m_Value(B))) &&
1853*09467b48Spatrick         match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))
1854*09467b48Spatrick       return BinaryOperator::CreateXor(A, B);
1855*09467b48Spatrick   }
1856*09467b48Spatrick 
1857*09467b48Spatrick   // (sub (and A, B) (or A, B)) --> neg (xor A, B)
1858*09467b48Spatrick   {
1859*09467b48Spatrick     Value *A, *B;
1860*09467b48Spatrick     if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
1861*09467b48Spatrick         match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) &&
1862*09467b48Spatrick         (Op0->hasOneUse() || Op1->hasOneUse()))
1863*09467b48Spatrick       return BinaryOperator::CreateNeg(Builder.CreateXor(A, B));
1864*09467b48Spatrick   }
1865*09467b48Spatrick 
1866*09467b48Spatrick   // (sub (or A, B), (xor A, B)) --> (and A, B)
1867*09467b48Spatrick   {
1868*09467b48Spatrick     Value *A, *B;
1869*09467b48Spatrick     if (match(Op1, m_Xor(m_Value(A), m_Value(B))) &&
1870*09467b48Spatrick         match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))
1871*09467b48Spatrick       return BinaryOperator::CreateAnd(A, B);
1872*09467b48Spatrick   }
1873*09467b48Spatrick 
1874*09467b48Spatrick   // (sub (xor A, B) (or A, B)) --> neg (and A, B)
1875*09467b48Spatrick   {
1876*09467b48Spatrick     Value *A, *B;
1877*09467b48Spatrick     if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
1878*09467b48Spatrick         match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) &&
1879*09467b48Spatrick         (Op0->hasOneUse() || Op1->hasOneUse()))
1880*09467b48Spatrick       return BinaryOperator::CreateNeg(Builder.CreateAnd(A, B));
1881*09467b48Spatrick   }
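
  // The four or/and/xor folds above are borrow-free identities (editor's
  // note, not from the original source): the set bits of A & B and of A ^ B
  // are both subsets of those of A | B, so the subtractions reduce to
  // bitwise and-not: (A|B) - (A&B) == A^B and (A|B) - (A^B) == A&B.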
1882*09467b48Spatrick 
1883*09467b48Spatrick   {
1884*09467b48Spatrick     Value *Y;
1885*09467b48Spatrick     // ((X | Y) - X) --> (~X & Y)
1886*09467b48Spatrick     if (match(Op0, m_OneUse(m_c_Or(m_Value(Y), m_Specific(Op1)))))
1887*09467b48Spatrick       return BinaryOperator::CreateAnd(
1888*09467b48Spatrick           Y, Builder.CreateNot(Op1, Op1->getName() + ".not"));
1889*09467b48Spatrick   }
1890*09467b48Spatrick 
1891*09467b48Spatrick   {
1892*09467b48Spatrick     // (sub (and Op1, (neg X)), Op1) --> neg (and Op1, (add X, -1))
1893*09467b48Spatrick     Value *X;
1894*09467b48Spatrick     if (match(Op0, m_OneUse(m_c_And(m_Specific(Op1),
1895*09467b48Spatrick                                     m_OneUse(m_Neg(m_Value(X))))))) {
1896*09467b48Spatrick       return BinaryOperator::CreateNeg(Builder.CreateAnd(
1897*09467b48Spatrick           Op1, Builder.CreateAdd(X, Constant::getAllOnesValue(I.getType()))));
1898*09467b48Spatrick     }
1899*09467b48Spatrick   }
1900*09467b48Spatrick 
1901*09467b48Spatrick   {
1902*09467b48Spatrick     // (sub (and Op1, C), Op1) --> neg (and Op1, ~C)
1903*09467b48Spatrick     Constant *C;
1904*09467b48Spatrick     if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_Constant(C))))) {
1905*09467b48Spatrick       return BinaryOperator::CreateNeg(
1906*09467b48Spatrick           Builder.CreateAnd(Op1, Builder.CreateNot(C)));
1907*09467b48Spatrick     }
1908*09467b48Spatrick   }
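
  // Both 'and' folds above follow from Op1 == (Op1 & C) + (Op1 & ~C)
  // (editor's note, not from the original source), with the (neg X) variant
  // using ~(-X) == X + -1 to rewrite the mask.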
1909*09467b48Spatrick 
1910*09467b48Spatrick   {
1911*09467b48Spatrick     // If we have a subtraction between some value and a select between
1912*09467b48Spatrick     // said value and something else, sink subtraction into select hands, i.e.:
1913*09467b48Spatrick     //   sub (select %Cond, %TrueVal, %FalseVal), %Op1
1914*09467b48Spatrick     //     ->
1915*09467b48Spatrick     //   select %Cond, (sub %TrueVal, %Op1), (sub %FalseVal, %Op1)
1916*09467b48Spatrick     //  or
1917*09467b48Spatrick     //   sub %Op0, (select %Cond, %TrueVal, %FalseVal)
1918*09467b48Spatrick     //     ->
1919*09467b48Spatrick     //   select %Cond, (sub %Op0, %TrueVal), (sub %Op0, %FalseVal)
1920*09467b48Spatrick     // This will result in select between new subtraction and 0.
1921*09467b48Spatrick     auto SinkSubIntoSelect =
1922*09467b48Spatrick         [Ty = I.getType()](Value *Select, Value *OtherHandOfSub,
1923*09467b48Spatrick                            auto SubBuilder) -> Instruction * {
1924*09467b48Spatrick       Value *Cond, *TrueVal, *FalseVal;
1925*09467b48Spatrick       if (!match(Select, m_OneUse(m_Select(m_Value(Cond), m_Value(TrueVal),
1926*09467b48Spatrick                                            m_Value(FalseVal)))))
1927*09467b48Spatrick         return nullptr;
1928*09467b48Spatrick       if (OtherHandOfSub != TrueVal && OtherHandOfSub != FalseVal)
1929*09467b48Spatrick         return nullptr;
1930*09467b48Spatrick       // While it is really tempting to just create two subtractions and let
1931*09467b48Spatrick       // InstCombine fold one of those to 0, it isn't possible to do so
1932*09467b48Spatrick       // because of worklist visitation order. So ugly it is.
1933*09467b48Spatrick       bool OtherHandOfSubIsTrueVal = OtherHandOfSub == TrueVal;
1934*09467b48Spatrick       Value *NewSub = SubBuilder(OtherHandOfSubIsTrueVal ? FalseVal : TrueVal);
1935*09467b48Spatrick       Constant *Zero = Constant::getNullValue(Ty);
1936*09467b48Spatrick       SelectInst *NewSel =
1937*09467b48Spatrick           SelectInst::Create(Cond, OtherHandOfSubIsTrueVal ? Zero : NewSub,
1938*09467b48Spatrick                              OtherHandOfSubIsTrueVal ? NewSub : Zero);
1939*09467b48Spatrick       // Preserve prof metadata if any.
1940*09467b48Spatrick       NewSel->copyMetadata(cast<Instruction>(*Select));
1941*09467b48Spatrick       return NewSel;
1942*09467b48Spatrick     };
1943*09467b48Spatrick     if (Instruction *NewSel = SinkSubIntoSelect(
1944*09467b48Spatrick             /*Select=*/Op0, /*OtherHandOfSub=*/Op1,
1945*09467b48Spatrick             [Builder = &Builder, Op1](Value *OtherHandOfSelect) {
1946*09467b48Spatrick               return Builder->CreateSub(OtherHandOfSelect,
1947*09467b48Spatrick                                         /*OtherHandOfSub=*/Op1);
1948*09467b48Spatrick             }))
1949*09467b48Spatrick       return NewSel;
1950*09467b48Spatrick     if (Instruction *NewSel = SinkSubIntoSelect(
1951*09467b48Spatrick             /*Select=*/Op1, /*OtherHandOfSub=*/Op0,
1952*09467b48Spatrick             [Builder = &Builder, Op0](Value *OtherHandOfSelect) {
1953*09467b48Spatrick               return Builder->CreateSub(/*OtherHandOfSub=*/Op0,
1954*09467b48Spatrick                                         OtherHandOfSelect);
1955*09467b48Spatrick             }))
1956*09467b48Spatrick       return NewSel;
1957*09467b48Spatrick   }
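
  // Illustrative sketch of the sinking (editor's example, not from the
  // original source; %c/%f/%op1 are hypothetical):
  //   %s = select i1 %c, i32 %op1, i32 %f
  //   %r = sub i32 %s, %op1
  // becomes a select between a known zero and the remaining subtraction:
  //   %d = sub i32 %f, %op1
  //   %r = select i1 %c, i32 0, i32 %d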
1958*09467b48Spatrick 
1959*09467b48Spatrick   if (Op1->hasOneUse()) {
1960*09467b48Spatrick     Value *X = nullptr, *Y = nullptr, *Z = nullptr;
1961*09467b48Spatrick     Constant *C = nullptr;
1962*09467b48Spatrick 
1963*09467b48Spatrick     // (X - (Y - Z))  -->  (X + (Z - Y)).
1964*09467b48Spatrick     if (match(Op1, m_Sub(m_Value(Y), m_Value(Z))))
1965*09467b48Spatrick       return BinaryOperator::CreateAdd(Op0,
1966*09467b48Spatrick                                       Builder.CreateSub(Z, Y, Op1->getName()));
1967*09467b48Spatrick 
1968*09467b48Spatrick     // (X - (X & Y))   -->   (X & ~Y)
1969*09467b48Spatrick     if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0))))
1970*09467b48Spatrick       return BinaryOperator::CreateAnd(Op0,
1971*09467b48Spatrick                                   Builder.CreateNot(Y, Y->getName() + ".not"));
1972*09467b48Spatrick 
1973*09467b48Spatrick     // 0 - (X sdiv C)  -> (X sdiv -C)  provided the negation doesn't overflow.
1974*09467b48Spatrick     if (match(Op0, m_Zero())) {
1975*09467b48Spatrick       Constant *Op11C;
1976*09467b48Spatrick       if (match(Op1, m_SDiv(m_Value(X), m_Constant(Op11C))) &&
1977*09467b48Spatrick           !Op11C->containsUndefElement() && Op11C->isNotMinSignedValue() &&
1978*09467b48Spatrick           Op11C->isNotOneValue()) {
1979*09467b48Spatrick         Instruction *BO =
1980*09467b48Spatrick             BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(Op11C));
1981*09467b48Spatrick         BO->setIsExact(cast<BinaryOperator>(Op1)->isExact());
1982*09467b48Spatrick         return BO;
1983*09467b48Spatrick       }
1984*09467b48Spatrick     }
1985*09467b48Spatrick 
1986*09467b48Spatrick     // 0 - (X << Y)  -> (-X << Y)   when X is freely negatable.
1987*09467b48Spatrick     if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero()))
1988*09467b48Spatrick       if (Value *XNeg = dyn_castNegVal(X))
1989*09467b48Spatrick         return BinaryOperator::CreateShl(XNeg, Y);
1990*09467b48Spatrick 
1991*09467b48Spatrick     // Subtracting -1/0 is the same as adding 1/0:
1992*09467b48Spatrick     // sub [nsw] Op0, sext(bool Y) -> add [nsw] Op0, zext(bool Y)
1993*09467b48Spatrick     // 'nuw' is dropped in favor of the canonical form.
1994*09467b48Spatrick     if (match(Op1, m_SExt(m_Value(Y))) &&
1995*09467b48Spatrick         Y->getType()->getScalarSizeInBits() == 1) {
1996*09467b48Spatrick       Value *Zext = Builder.CreateZExt(Y, I.getType());
1997*09467b48Spatrick       BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Zext);
1998*09467b48Spatrick       Add->setHasNoSignedWrap(I.hasNoSignedWrap());
1999*09467b48Spatrick       return Add;
2000*09467b48Spatrick     }
2001*09467b48Spatrick     // sub [nsw] X, zext(bool Y) -> add [nsw] X, sext(bool Y)
2002*09467b48Spatrick     // 'nuw' is dropped in favor of the canonical form.
2003*09467b48Spatrick     if (match(Op1, m_ZExt(m_Value(Y))) && Y->getType()->isIntOrIntVectorTy(1)) {
2004*09467b48Spatrick       Value *Sext = Builder.CreateSExt(Y, I.getType());
2005*09467b48Spatrick       BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Sext);
2006*09467b48Spatrick       Add->setHasNoSignedWrap(I.hasNoSignedWrap());
2007*09467b48Spatrick       return Add;
2008*09467b48Spatrick     }
2009*09467b48Spatrick 
2010*09467b48Spatrick     // X - A*-B -> X + A*B
2011*09467b48Spatrick     // X - -A*B -> X + A*B
2012*09467b48Spatrick     Value *A, *B;
2013*09467b48Spatrick     if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B)))))
2014*09467b48Spatrick       return BinaryOperator::CreateAdd(Op0, Builder.CreateMul(A, B));
2015*09467b48Spatrick 
2016*09467b48Spatrick     // X - A*C -> X + A*-C
2017*09467b48Spatrick     // No need to handle commuted multiply because multiply handling will
2018*09467b48Spatrick     // ensure the constant will be moved to the right-hand side.
2019*09467b48Spatrick     if (match(Op1, m_Mul(m_Value(A), m_Constant(C))) && !isa<ConstantExpr>(C)) {
2020*09467b48Spatrick       Value *NewMul = Builder.CreateMul(A, ConstantExpr::getNeg(C));
2021*09467b48Spatrick       return BinaryOperator::CreateAdd(Op0, NewMul);
2022*09467b48Spatrick     }
2023*09467b48Spatrick   }
2024*09467b48Spatrick 
2025*09467b48Spatrick   {
2026*09467b48Spatrick     // ~A - Min/Max(~A, O) -> Max/Min(A, ~O) - A
2027*09467b48Spatrick     // ~A - Min/Max(O, ~A) -> Max/Min(A, ~O) - A
2028*09467b48Spatrick     // Min/Max(~A, O) - ~A -> A - Max/Min(A, ~O)
2029*09467b48Spatrick     // Min/Max(O, ~A) - ~A -> A - Max/Min(A, ~O)
2030*09467b48Spatrick     // So long as O here is freely invertible, this will be neutral or a win.
2031*09467b48Spatrick     Value *LHS, *RHS, *A;
2032*09467b48Spatrick     Value *NotA = Op0, *MinMax = Op1;
2033*09467b48Spatrick     SelectPatternFlavor SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor;
2034*09467b48Spatrick     if (!SelectPatternResult::isMinOrMax(SPF)) {
2035*09467b48Spatrick       NotA = Op1;
2036*09467b48Spatrick       MinMax = Op0;
2037*09467b48Spatrick       SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor;
2038*09467b48Spatrick     }
2039*09467b48Spatrick     if (SelectPatternResult::isMinOrMax(SPF) &&
2040*09467b48Spatrick         match(NotA, m_Not(m_Value(A))) && (NotA == LHS || NotA == RHS)) {
2041*09467b48Spatrick       if (NotA == LHS)
2042*09467b48Spatrick         std::swap(LHS, RHS);
2043*09467b48Spatrick       // LHS is now O above and expected to have at least 2 uses (the min/max);
2044*09467b48Spatrick       // NotA is expected to have 2 uses from the min/max and 1 from the sub.
2045*09467b48Spatrick       if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
2046*09467b48Spatrick           !NotA->hasNUsesOrMore(4)) {
2047*09467b48Spatrick         // Note: We don't generate the inverse max/min, just create the not of
2048*09467b48Spatrick         // it and let other folds do the rest.
2049*09467b48Spatrick         Value *Not = Builder.CreateNot(MinMax);
2050*09467b48Spatrick         if (NotA == Op0)
2051*09467b48Spatrick           return BinaryOperator::CreateSub(Not, A);
2052*09467b48Spatrick         else
2053*09467b48Spatrick           return BinaryOperator::CreateSub(A, Not);
2054*09467b48Spatrick       }
2055*09467b48Spatrick     }
2056*09467b48Spatrick   }
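                        // A hypothetical sketch of the ~A - min(~A, O) case:
                        //   %na = xor i32 %a, -1                  ; ~A
                        //   %mm = select i1 %c, i32 %na, i32 %o   ; min/max(~A, O)
                        //   %r  = sub i32 %na, %mm
                        // becomes:
                        //   %nm = xor i32 %mm, -1                 ; not of the min/max
                        //   %r  = sub i32 %nm, %a
                        // and later folds are expected to turn %nm into the inverse
                        // max/min of (%a, ~%o).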
2057*09467b48Spatrick 
2058*09467b48Spatrick   // Optimize the difference of two pointers into the same array into a size.  Consider:
2059*09467b48Spatrick   //  &A[10] - &A[0]: we should compile this to "10".
2060*09467b48Spatrick   Value *LHSOp, *RHSOp;
2061*09467b48Spatrick   if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
2062*09467b48Spatrick       match(Op1, m_PtrToInt(m_Value(RHSOp))))
2063*09467b48Spatrick     if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(),
2064*09467b48Spatrick                                                I.hasNoUnsignedWrap()))
2065*09467b48Spatrick       return replaceInstUsesWith(I, Res);
2066*09467b48Spatrick 
2067*09467b48Spatrick   // trunc(p)-trunc(q) -> trunc(p-q)
2068*09467b48Spatrick   if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
2069*09467b48Spatrick       match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
2070*09467b48Spatrick     if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(),
2071*09467b48Spatrick                                                /* IsNUW */ false))
2072*09467b48Spatrick       return replaceInstUsesWith(I, Res);
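                        // A hypothetical sketch of the plain ptrtoint case:
                        //   %p = getelementptr inbounds i32, i32* %a, i64 10
                        //   %x = ptrtoint i32* %p to i64
                        //   %y = ptrtoint i32* %a to i64
                        //   %d = sub i64 %x, %y   ; folds to the constant byte offset, 40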
2073*09467b48Spatrick 
2074*09467b48Spatrick   // Canonicalize a shifty way to code absolute value to the common pattern.
2075*09467b48Spatrick   // There are 2 potential commuted variants.
2076*09467b48Spatrick   // We're relying on the fact that we only do this transform when the shift has
2077*09467b48Spatrick   // exactly 2 uses and the xor has exactly 1 use (otherwise, we might increase
2078*09467b48Spatrick   // instructions).
2079*09467b48Spatrick   Value *A;
2080*09467b48Spatrick   const APInt *ShAmt;
2081*09467b48Spatrick   Type *Ty = I.getType();
2082*09467b48Spatrick   if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) &&
2083*09467b48Spatrick       Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 &&
2084*09467b48Spatrick       match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) {
2085*09467b48Spatrick     // B = ashr i32 A, 31 ; smear the sign bit
2086*09467b48Spatrick     // sub (xor A, B), B  ; flip bits if negative and subtract -1 (add 1)
2087*09467b48Spatrick     // --> (A < 0) ? -A : A
2088*09467b48Spatrick     Value *Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty));
2089*09467b48Spatrick     // Copy the nuw/nsw flags from the sub to the negate.
2090*09467b48Spatrick     Value *Neg = Builder.CreateNeg(A, "", I.hasNoUnsignedWrap(),
2091*09467b48Spatrick                                    I.hasNoSignedWrap());
2092*09467b48Spatrick     return SelectInst::Create(Cmp, Neg, A);
2093*09467b48Spatrick   }
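                        // Illustrative before/after for this canonicalization
                        // (hypothetical values):
                        //   %b = ashr i32 %a, 31
                        //   %x = xor i32 %a, %b
                        //   %r = sub i32 %x, %b
                        // becomes:
                        //   %c = icmp slt i32 %a, 0
                        //   %n = sub i32 0, %a
                        //   %r = select i1 %c, i32 %n, i32 %a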
2094*09467b48Spatrick 
2095*09467b48Spatrick   if (Instruction *V =
2096*09467b48Spatrick           canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
2097*09467b48Spatrick     return V;
2098*09467b48Spatrick 
2099*09467b48Spatrick   if (Instruction *Ext = narrowMathIfNoOverflow(I))
2100*09467b48Spatrick     return Ext;
2101*09467b48Spatrick 
2102*09467b48Spatrick   bool Changed = false;
2103*09467b48Spatrick   if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) {
2104*09467b48Spatrick     Changed = true;
2105*09467b48Spatrick     I.setHasNoSignedWrap(true);
2106*09467b48Spatrick   }
2107*09467b48Spatrick   if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) {
2108*09467b48Spatrick     Changed = true;
2109*09467b48Spatrick     I.setHasNoUnsignedWrap(true);
2110*09467b48Spatrick   }
2111*09467b48Spatrick 
2112*09467b48Spatrick   return Changed ? &I : nullptr;
2113*09467b48Spatrick }
2114*09467b48Spatrick 
2115*09467b48Spatrick /// This eliminates floating-point negation in either 'fneg(X)' or
2116*09467b48Spatrick /// 'fsub(-0.0, X)' form by folding the negation into a constant operand.
2117*09467b48Spatrick static Instruction *foldFNegIntoConstant(Instruction &I) {
2118*09467b48Spatrick   Value *X;
2119*09467b48Spatrick   Constant *C;
2120*09467b48Spatrick 
2121*09467b48Spatrick   // Fold negation into a constant operand. This is limited to one-use because
2122*09467b48Spatrick   // fneg is assumed better for analysis and cheaper in codegen than fmul/fdiv.
2123*09467b48Spatrick   // -(X * C) --> X * (-C)
2124*09467b48Spatrick   // FIXME: It's arguable whether these should be m_OneUse or not. The current
2125*09467b48Spatrick   // belief is that the FNeg allows for better reassociation opportunities.
2126*09467b48Spatrick   if (match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C))))))
2127*09467b48Spatrick     return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I);
2128*09467b48Spatrick   // -(X / C) --> X / (-C)
2129*09467b48Spatrick   if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C))))))
2130*09467b48Spatrick     return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I);
2131*09467b48Spatrick   // -(C / X) --> (-C) / X
2132*09467b48Spatrick   if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X))))))
2133*09467b48Spatrick     return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I);
2134*09467b48Spatrick 
2135*09467b48Spatrick   return nullptr;
2136*09467b48Spatrick }
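                      // For example (hypothetical IR): "fneg (fmul float %x, 4.0)" becomes
                      // "fmul float %x, -4.0" when the fmul has no other uses.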
2137*09467b48Spatrick 
2138*09467b48Spatrick static Instruction *hoistFNegAboveFMulFDiv(Instruction &I,
2139*09467b48Spatrick                                            InstCombiner::BuilderTy &Builder) {
2140*09467b48Spatrick   Value *FNeg;
2141*09467b48Spatrick   if (!match(&I, m_FNeg(m_Value(FNeg))))
2142*09467b48Spatrick     return nullptr;
2143*09467b48Spatrick 
2144*09467b48Spatrick   Value *X, *Y;
2145*09467b48Spatrick   if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))
2146*09467b48Spatrick     return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
2147*09467b48Spatrick 
2148*09467b48Spatrick   if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))))
2149*09467b48Spatrick     return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
2150*09467b48Spatrick 
2151*09467b48Spatrick   return nullptr;
2152*09467b48Spatrick }
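                      // Illustrative example: "fneg (fmul float %x, %y)" is rewritten to
                      // "fmul float (fneg float %x), %y"; with the one-use check, the
                      // instruction count stays flat while the fneg moves toward an operand.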
2153*09467b48Spatrick 
2154*09467b48Spatrick Instruction *InstCombiner::visitFNeg(UnaryOperator &I) {
2155*09467b48Spatrick   Value *Op = I.getOperand(0);
2156*09467b48Spatrick 
2157*09467b48Spatrick   if (Value *V = SimplifyFNegInst(Op, I.getFastMathFlags(),
2158*09467b48Spatrick                                   SQ.getWithInstruction(&I)))
2159*09467b48Spatrick     return replaceInstUsesWith(I, V);
2160*09467b48Spatrick 
2161*09467b48Spatrick   if (Instruction *X = foldFNegIntoConstant(I))
2162*09467b48Spatrick     return X;
2163*09467b48Spatrick 
2164*09467b48Spatrick   Value *X, *Y;
2165*09467b48Spatrick 
2166*09467b48Spatrick   // If we can ignore the sign of zeros: -(X - Y) --> (Y - X)
2167*09467b48Spatrick   if (I.hasNoSignedZeros() &&
2168*09467b48Spatrick       match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y)))))
2169*09467b48Spatrick     return BinaryOperator::CreateFSubFMF(Y, X, &I);
2170*09467b48Spatrick 
2171*09467b48Spatrick   if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
2172*09467b48Spatrick     return R;
2173*09467b48Spatrick 
2174*09467b48Spatrick   return nullptr;
2175*09467b48Spatrick }
2176*09467b48Spatrick 
2177*09467b48Spatrick Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
2178*09467b48Spatrick   if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1),
2179*09467b48Spatrick                                   I.getFastMathFlags(),
2180*09467b48Spatrick                                   SQ.getWithInstruction(&I)))
2181*09467b48Spatrick     return replaceInstUsesWith(I, V);
2182*09467b48Spatrick 
2183*09467b48Spatrick   if (Instruction *X = foldVectorBinop(I))
2184*09467b48Spatrick     return X;
2185*09467b48Spatrick 
2186*09467b48Spatrick   // Subtraction from -0.0 is the canonical form of fneg.
2187*09467b48Spatrick   // fsub nsz 0, X ==> fsub nsz -0.0, X
2188*09467b48Spatrick   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
2189*09467b48Spatrick   if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP()))
2190*09467b48Spatrick     return BinaryOperator::CreateFNegFMF(Op1, &I);
2191*09467b48Spatrick 
2192*09467b48Spatrick   if (Instruction *X = foldFNegIntoConstant(I))
2193*09467b48Spatrick     return X;
2194*09467b48Spatrick 
2195*09467b48Spatrick   if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
2196*09467b48Spatrick     return R;
2197*09467b48Spatrick 
2198*09467b48Spatrick   Value *X, *Y;
2199*09467b48Spatrick   Constant *C;
2200*09467b48Spatrick 
2201*09467b48Spatrick   // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X)
2202*09467b48Spatrick   // Canonicalize to fadd to make analysis easier.
2203*09467b48Spatrick   // This can also help codegen because fadd is commutative.
2204*09467b48Spatrick   // Note that if this fsub was really an fneg, the fadd with -0.0 will get
2205*09467b48Spatrick   // killed later. We still limit that particular transform with 'hasOneUse'
2206*09467b48Spatrick   // because an fneg is assumed better/cheaper than a generic fsub.
2207*09467b48Spatrick   if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) {
2208*09467b48Spatrick     if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) {
2209*09467b48Spatrick       Value *NewSub = Builder.CreateFSubFMF(Y, X, &I);
2210*09467b48Spatrick       return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I);
2211*09467b48Spatrick     }
2212*09467b48Spatrick   }
2213*09467b48Spatrick 
2214*09467b48Spatrick   if (isa<Constant>(Op0))
2215*09467b48Spatrick     if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
2216*09467b48Spatrick       if (Instruction *NV = FoldOpIntoSelect(I, SI))
2217*09467b48Spatrick         return NV;
2218*09467b48Spatrick 
2219*09467b48Spatrick   // X - C --> X + (-C)
2220*09467b48Spatrick   // But don't transform constant expressions because there's an inverse fold
2221*09467b48Spatrick   // for X + (-Y) --> X - Y.
2222*09467b48Spatrick   if (match(Op1, m_Constant(C)) && !isa<ConstantExpr>(Op1))
2223*09467b48Spatrick     return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I);
2224*09467b48Spatrick 
2225*09467b48Spatrick   // X - (-Y) --> X + Y
2226*09467b48Spatrick   if (match(Op1, m_FNeg(m_Value(Y))))
2227*09467b48Spatrick     return BinaryOperator::CreateFAddFMF(Op0, Y, &I);
2228*09467b48Spatrick 
2229*09467b48Spatrick   // Similar to above, but look through a cast of the negated value:
2230*09467b48Spatrick   // X - (fptrunc(-Y)) --> X + fptrunc(Y)
2231*09467b48Spatrick   Type *Ty = I.getType();
2232*09467b48Spatrick   if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y))))))
2233*09467b48Spatrick     return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPTrunc(Y, Ty), &I);
2234*09467b48Spatrick 
2235*09467b48Spatrick   // X - (fpext(-Y)) --> X + fpext(Y)
2236*09467b48Spatrick   if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y))))))
2237*09467b48Spatrick     return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I);
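                        // e.g. (hypothetical): "fsub double %x, (fpext (fneg float %y) to double)"
                        //   --> "fadd double %x, (fpext float %y to double)"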
2238*09467b48Spatrick 
2239*09467b48Spatrick   // Similar to above, but look through fmul/fdiv of the negated value:
2240*09467b48Spatrick   // Op0 - (-X * Y) --> Op0 + (X * Y)
2241*09467b48Spatrick   // Op0 - (Y * -X) --> Op0 + (X * Y)
2242*09467b48Spatrick   if (match(Op1, m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))))) {
2243*09467b48Spatrick     Value *FMul = Builder.CreateFMulFMF(X, Y, &I);
2244*09467b48Spatrick     return BinaryOperator::CreateFAddFMF(Op0, FMul, &I);
2245*09467b48Spatrick   }
2246*09467b48Spatrick   // Op0 - (-X / Y) --> Op0 + (X / Y)
2247*09467b48Spatrick   // Op0 - (X / -Y) --> Op0 + (X / Y)
2248*09467b48Spatrick   if (match(Op1, m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y)))) ||
2249*09467b48Spatrick       match(Op1, m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))))) {
2250*09467b48Spatrick     Value *FDiv = Builder.CreateFDivFMF(X, Y, &I);
2251*09467b48Spatrick     return BinaryOperator::CreateFAddFMF(Op0, FDiv, &I);
2252*09467b48Spatrick   }
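                        // e.g. (hypothetical): "fsub float %z, (fdiv float (fneg float %x), %y)"
                        //   --> "fadd float %z, (fdiv float %x, %y)"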
2253*09467b48Spatrick 
2254*09467b48Spatrick   // Handle special cases for FSub with selects feeding the operation
2255*09467b48Spatrick   if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
2256*09467b48Spatrick     return replaceInstUsesWith(I, V);
2257*09467b48Spatrick 
2258*09467b48Spatrick   if (I.hasAllowReassoc() && I.hasNoSignedZeros()) {
2259*09467b48Spatrick     // (Y - X) - Y --> -X
2260*09467b48Spatrick     if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X))))
2261*09467b48Spatrick       return BinaryOperator::CreateFNegFMF(X, &I);
2262*09467b48Spatrick 
2263*09467b48Spatrick     // Y - (X + Y) --> -X
2264*09467b48Spatrick     // Y - (Y + X) --> -X
2265*09467b48Spatrick     if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X))))
2266*09467b48Spatrick       return BinaryOperator::CreateFNegFMF(X, &I);
2267*09467b48Spatrick 
2268*09467b48Spatrick     // (X * C) - X --> X * (C - 1.0)
2269*09467b48Spatrick     if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) {
2270*09467b48Spatrick       Constant *CSubOne = ConstantExpr::getFSub(C, ConstantFP::get(Ty, 1.0));
2271*09467b48Spatrick       return BinaryOperator::CreateFMulFMF(Op1, CSubOne, &I);
2272*09467b48Spatrick     }
2273*09467b48Spatrick     // X - (X * C) --> X * (1.0 - C)
2274*09467b48Spatrick     if (match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) {
2275*09467b48Spatrick       Constant *OneSubC = ConstantExpr::getFSub(ConstantFP::get(Ty, 1.0), C);
2276*09467b48Spatrick       return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I);
2277*09467b48Spatrick     }
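                          // Worked constant examples (hypothetical): (%x * 3.0) - %x folds to
                          // %x * 2.0, and %x - (%x * 3.0) folds to %x * -2.0.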
2278*09467b48Spatrick 
2279*09467b48Spatrick     if (Instruction *F = factorizeFAddFSub(I, Builder))
2280*09467b48Spatrick       return F;
2281*09467b48Spatrick 
2282*09467b48Spatrick     // TODO: This performs reassociative folds for FP ops. Some fraction of the
2283*09467b48Spatrick     // functionality has been subsumed by simple pattern matching here and in
2284*09467b48Spatrick     // InstSimplify. We should let a dedicated reassociation pass handle more
2285*09467b48Spatrick     // complex pattern matching and remove this from InstCombine.
2286*09467b48Spatrick     if (Value *V = FAddCombine(Builder).simplify(&I))
2287*09467b48Spatrick       return replaceInstUsesWith(I, V);
2288*09467b48Spatrick   }
2289*09467b48Spatrick 
2290*09467b48Spatrick   return nullptr;
2291*09467b48Spatrick }
2292