xref: /llvm-project/llvm/tools/llvm-reduce/deltas/ReduceOperands.cpp (revision 23cc36e4765912a1bcdbbc3fb8b0976a06dea043)
196605639SArthur Eubanks //===----------------------------------------------------------------------===//
2f18c0739SSamuel //
3f18c0739SSamuel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f18c0739SSamuel // See https://llvm.org/LICENSE.txt for license information.
5f18c0739SSamuel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f18c0739SSamuel //
7f18c0739SSamuel //===----------------------------------------------------------------------===//
8f18c0739SSamuel 
9f18c0739SSamuel #include "ReduceOperands.h"
10f18c0739SSamuel #include "llvm/IR/Constants.h"
11f18c0739SSamuel #include "llvm/IR/InstIterator.h"
12be0b47d5SArthur Eubanks #include "llvm/IR/InstrTypes.h"
1396605639SArthur Eubanks #include "llvm/IR/Operator.h"
146b8bd0f7SMatt Arsenault #include "llvm/IR/PatternMatch.h"
1596605639SArthur Eubanks #include "llvm/IR/Type.h"
16f18c0739SSamuel 
17f18c0739SSamuel using namespace llvm;
186b8bd0f7SMatt Arsenault using namespace PatternMatch;
19f18c0739SSamuel 
2096605639SArthur Eubanks static void
extractOperandsFromModule(Oracle & O,ReducerWorkItem & WorkItem,function_ref<Value * (Use &)> ReduceValue)21*23cc36e4SMatt Arsenault extractOperandsFromModule(Oracle &O, ReducerWorkItem &WorkItem,
2296605639SArthur Eubanks                           function_ref<Value *(Use &)> ReduceValue) {
23*23cc36e4SMatt Arsenault   Module &Program = WorkItem.getModule();
24*23cc36e4SMatt Arsenault 
2577bc3ba3SArthur Eubanks   for (auto &F : Program.functions()) {
26f18c0739SSamuel     for (auto &I : instructions(&F)) {
270a159427SMatt Arsenault       if (PHINode *Phi = dyn_cast<PHINode>(&I)) {
280a159427SMatt Arsenault         for (auto &Op : Phi->incoming_values()) {
290a159427SMatt Arsenault           if (!O.shouldKeep()) {
300a159427SMatt Arsenault             if (Value *Reduced = ReduceValue(Op))
310a159427SMatt Arsenault               Phi->setIncomingValueForBlock(Phi->getIncomingBlock(Op), Reduced);
320a159427SMatt Arsenault           }
330a159427SMatt Arsenault         }
340a159427SMatt Arsenault 
350a159427SMatt Arsenault         continue;
360a159427SMatt Arsenault       }
370a159427SMatt Arsenault 
38f18c0739SSamuel       for (auto &Op : I.operands()) {
39ce3c3cb2SArthur Eubanks         if (Value *Reduced = ReduceValue(Op)) {
40ce3c3cb2SArthur Eubanks           if (!O.shouldKeep())
4196605639SArthur Eubanks             Op.set(Reduced);
42f18c0739SSamuel         }
43f18c0739SSamuel       }
44f18c0739SSamuel     }
45f18c0739SSamuel   }
4662b5aa98SMatt Arsenault }
47f18c0739SSamuel 
isOne(Use & Op)4896605639SArthur Eubanks static bool isOne(Use &Op) {
4996605639SArthur Eubanks   auto *C = dyn_cast<Constant>(Op);
5096605639SArthur Eubanks   return C && C->isOneValue();
5196605639SArthur Eubanks }
5296605639SArthur Eubanks 
isZero(Use & Op)5396605639SArthur Eubanks static bool isZero(Use &Op) {
5496605639SArthur Eubanks   auto *C = dyn_cast<Constant>(Op);
5596605639SArthur Eubanks   return C && C->isNullValue();
5696605639SArthur Eubanks }
5796605639SArthur Eubanks 
isZeroOrOneFP(Value * Op)586b8bd0f7SMatt Arsenault static bool isZeroOrOneFP(Value *Op) {
596b8bd0f7SMatt Arsenault   const APFloat *C;
606b8bd0f7SMatt Arsenault   return match(Op, m_APFloat(C)) &&
616b8bd0f7SMatt Arsenault          ((C->isZero() && !C->isNegative()) || C->isExactlyValue(1.0));
626b8bd0f7SMatt Arsenault }
636b8bd0f7SMatt Arsenault 
shouldReduceOperand(Use & Op)64be0b47d5SArthur Eubanks static bool shouldReduceOperand(Use &Op) {
65be0b47d5SArthur Eubanks   Type *Ty = Op->getType();
66be0b47d5SArthur Eubanks   if (Ty->isLabelTy() || Ty->isMetadataTy())
67be0b47d5SArthur Eubanks     return false;
68be0b47d5SArthur Eubanks   // TODO: be more precise about which GEP operands we can reduce (e.g. array
69be0b47d5SArthur Eubanks   // indexes)
70be0b47d5SArthur Eubanks   if (isa<GEPOperator>(Op.getUser()))
71be0b47d5SArthur Eubanks     return false;
72be0b47d5SArthur Eubanks   if (auto *CB = dyn_cast<CallBase>(Op.getUser())) {
73be0b47d5SArthur Eubanks     if (&CB->getCalledOperandUse() == &Op)
74be0b47d5SArthur Eubanks       return false;
75be0b47d5SArthur Eubanks   }
76be0b47d5SArthur Eubanks   return true;
77be0b47d5SArthur Eubanks }
78be0b47d5SArthur Eubanks 
switchCaseExists(Use & Op,ConstantInt * CI)795b4f6d8bSJohn Regehr static bool switchCaseExists(Use &Op, ConstantInt *CI) {
805b4f6d8bSJohn Regehr   SwitchInst *SI = dyn_cast<SwitchInst>(Op.getUser());
815b4f6d8bSJohn Regehr   if (!SI)
825b4f6d8bSJohn Regehr     return false;
835b4f6d8bSJohn Regehr   return SI->findCaseValue(CI) != SI->case_default();
845b4f6d8bSJohn Regehr }
855b4f6d8bSJohn Regehr 
reduceOperandsOneDeltaPass(TestRunner & Test)8696605639SArthur Eubanks void llvm::reduceOperandsOneDeltaPass(TestRunner &Test) {
8796605639SArthur Eubanks   auto ReduceValue = [](Use &Op) -> Value * {
88be0b47d5SArthur Eubanks     if (!shouldReduceOperand(Op))
8996605639SArthur Eubanks       return nullptr;
906b8bd0f7SMatt Arsenault 
916b8bd0f7SMatt Arsenault     Type *Ty = Op->getType();
926b8bd0f7SMatt Arsenault     if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
935b4f6d8bSJohn Regehr       // Don't duplicate an existing switch case.
945b4f6d8bSJohn Regehr       if (switchCaseExists(Op, ConstantInt::get(IntTy, 1)))
955b4f6d8bSJohn Regehr         return nullptr;
9696605639SArthur Eubanks       // Don't replace existing ones and zeroes.
976b8bd0f7SMatt Arsenault       return (isOne(Op) || isZero(Op)) ? nullptr : ConstantInt::get(IntTy, 1);
986b8bd0f7SMatt Arsenault     }
996b8bd0f7SMatt Arsenault 
1006b8bd0f7SMatt Arsenault     if (Ty->isFloatingPointTy())
1016b8bd0f7SMatt Arsenault       return isZeroOrOneFP(Op) ? nullptr : ConstantFP::get(Ty, 1.0);
1026b8bd0f7SMatt Arsenault 
1036b8bd0f7SMatt Arsenault     if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
104bb3f99cdSFraser Cormack       if (isOne(Op) || isZero(Op) || isZeroOrOneFP(Op))
1056b8bd0f7SMatt Arsenault         return nullptr;
1066b8bd0f7SMatt Arsenault 
107a946eb16SMatthias Braun       Type *ElementType = VT->getElementType();
108a946eb16SMatthias Braun       Constant *C;
109a946eb16SMatthias Braun       if (ElementType->isFloatingPointTy()) {
110a946eb16SMatthias Braun         C = ConstantFP::get(ElementType, 1.0);
111a946eb16SMatthias Braun       } else if (IntegerType *IntTy = dyn_cast<IntegerType>(ElementType)) {
112a946eb16SMatthias Braun         C = ConstantInt::get(IntTy, 1);
113a946eb16SMatthias Braun       } else {
114a946eb16SMatthias Braun         return nullptr;
115a946eb16SMatthias Braun       }
116a946eb16SMatthias Braun       return ConstantVector::getSplat(VT->getElementCount(), C);
1176b8bd0f7SMatt Arsenault     }
1186b8bd0f7SMatt Arsenault 
1196b8bd0f7SMatt Arsenault     return nullptr;
12096605639SArthur Eubanks   };
1212592ccdeSArthur Eubanks   runDeltaPass(
1222592ccdeSArthur Eubanks       Test,
123*23cc36e4SMatt Arsenault       [ReduceValue](Oracle &O, ReducerWorkItem &WorkItem) {
124*23cc36e4SMatt Arsenault         extractOperandsFromModule(O, WorkItem, ReduceValue);
1252592ccdeSArthur Eubanks       },
1262592ccdeSArthur Eubanks       "Reducing Operands to one");
12796605639SArthur Eubanks }
12896605639SArthur Eubanks 
reduceOperandsZeroDeltaPass(TestRunner & Test)12996605639SArthur Eubanks void llvm::reduceOperandsZeroDeltaPass(TestRunner &Test) {
13096605639SArthur Eubanks   auto ReduceValue = [](Use &Op) -> Value * {
131be0b47d5SArthur Eubanks     if (!shouldReduceOperand(Op))
13296605639SArthur Eubanks       return nullptr;
1335b4f6d8bSJohn Regehr     // Don't duplicate an existing switch case.
1345b4f6d8bSJohn Regehr     if (auto *IntTy = dyn_cast<IntegerType>(Op->getType()))
1355b4f6d8bSJohn Regehr       if (switchCaseExists(Op, ConstantInt::get(IntTy, 0)))
1365b4f6d8bSJohn Regehr         return nullptr;
13796605639SArthur Eubanks     // Don't replace existing zeroes.
13896605639SArthur Eubanks     return isZero(Op) ? nullptr : Constant::getNullValue(Op->getType());
13996605639SArthur Eubanks   };
1402592ccdeSArthur Eubanks   runDeltaPass(
1412592ccdeSArthur Eubanks       Test,
142*23cc36e4SMatt Arsenault       [ReduceValue](Oracle &O, ReducerWorkItem &Program) {
14396605639SArthur Eubanks         extractOperandsFromModule(O, Program, ReduceValue);
1442592ccdeSArthur Eubanks       },
1452592ccdeSArthur Eubanks       "Reducing Operands to zero");
146f18c0739SSamuel }
14726107559SMatt Arsenault 
reduceOperandsNaNDeltaPass(TestRunner & Test)14826107559SMatt Arsenault void llvm::reduceOperandsNaNDeltaPass(TestRunner &Test) {
14926107559SMatt Arsenault   auto ReduceValue = [](Use &Op) -> Value * {
15026107559SMatt Arsenault     Type *Ty = Op->getType();
15126107559SMatt Arsenault     if (!Ty->isFPOrFPVectorTy())
15226107559SMatt Arsenault       return nullptr;
15326107559SMatt Arsenault 
15426107559SMatt Arsenault     // Prefer 0.0 or 1.0 over NaN.
15526107559SMatt Arsenault     //
15626107559SMatt Arsenault     // TODO: Preferring NaN may make more sense because FP operations are more
15726107559SMatt Arsenault     // universally foldable.
15826107559SMatt Arsenault     if (match(Op.get(), m_NaN()) || isZeroOrOneFP(Op.get()))
15926107559SMatt Arsenault       return nullptr;
16026107559SMatt Arsenault 
16126107559SMatt Arsenault     if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
16226107559SMatt Arsenault       return ConstantVector::getSplat(VT->getElementCount(),
16326107559SMatt Arsenault                                       ConstantFP::getQNaN(VT->getElementType()));
16426107559SMatt Arsenault     }
16526107559SMatt Arsenault 
16626107559SMatt Arsenault     return ConstantFP::getQNaN(Ty);
16726107559SMatt Arsenault   };
1682592ccdeSArthur Eubanks   runDeltaPass(
1692592ccdeSArthur Eubanks       Test,
170*23cc36e4SMatt Arsenault       [ReduceValue](Oracle &O, ReducerWorkItem &Program) {
17126107559SMatt Arsenault         extractOperandsFromModule(O, Program, ReduceValue);
1722592ccdeSArthur Eubanks       },
1732592ccdeSArthur Eubanks       "Reducing Operands to NaN");
17426107559SMatt Arsenault }
175