//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8f18c0739SSamuel
9f18c0739SSamuel #include "ReduceOperands.h"
10f18c0739SSamuel #include "llvm/IR/Constants.h"
11f18c0739SSamuel #include "llvm/IR/InstIterator.h"
12be0b47d5SArthur Eubanks #include "llvm/IR/InstrTypes.h"
1396605639SArthur Eubanks #include "llvm/IR/Operator.h"
146b8bd0f7SMatt Arsenault #include "llvm/IR/PatternMatch.h"
1596605639SArthur Eubanks #include "llvm/IR/Type.h"
16f18c0739SSamuel
17f18c0739SSamuel using namespace llvm;
186b8bd0f7SMatt Arsenault using namespace PatternMatch;
19f18c0739SSamuel
2096605639SArthur Eubanks static void
extractOperandsFromModule(Oracle & O,ReducerWorkItem & WorkItem,function_ref<Value * (Use &)> ReduceValue)21*23cc36e4SMatt Arsenault extractOperandsFromModule(Oracle &O, ReducerWorkItem &WorkItem,
2296605639SArthur Eubanks function_ref<Value *(Use &)> ReduceValue) {
23*23cc36e4SMatt Arsenault Module &Program = WorkItem.getModule();
24*23cc36e4SMatt Arsenault
2577bc3ba3SArthur Eubanks for (auto &F : Program.functions()) {
26f18c0739SSamuel for (auto &I : instructions(&F)) {
270a159427SMatt Arsenault if (PHINode *Phi = dyn_cast<PHINode>(&I)) {
280a159427SMatt Arsenault for (auto &Op : Phi->incoming_values()) {
290a159427SMatt Arsenault if (!O.shouldKeep()) {
300a159427SMatt Arsenault if (Value *Reduced = ReduceValue(Op))
310a159427SMatt Arsenault Phi->setIncomingValueForBlock(Phi->getIncomingBlock(Op), Reduced);
320a159427SMatt Arsenault }
330a159427SMatt Arsenault }
340a159427SMatt Arsenault
350a159427SMatt Arsenault continue;
360a159427SMatt Arsenault }
370a159427SMatt Arsenault
38f18c0739SSamuel for (auto &Op : I.operands()) {
39ce3c3cb2SArthur Eubanks if (Value *Reduced = ReduceValue(Op)) {
40ce3c3cb2SArthur Eubanks if (!O.shouldKeep())
4196605639SArthur Eubanks Op.set(Reduced);
42f18c0739SSamuel }
43f18c0739SSamuel }
44f18c0739SSamuel }
45f18c0739SSamuel }
4662b5aa98SMatt Arsenault }
47f18c0739SSamuel
isOne(Use & Op)4896605639SArthur Eubanks static bool isOne(Use &Op) {
4996605639SArthur Eubanks auto *C = dyn_cast<Constant>(Op);
5096605639SArthur Eubanks return C && C->isOneValue();
5196605639SArthur Eubanks }
5296605639SArthur Eubanks
isZero(Use & Op)5396605639SArthur Eubanks static bool isZero(Use &Op) {
5496605639SArthur Eubanks auto *C = dyn_cast<Constant>(Op);
5596605639SArthur Eubanks return C && C->isNullValue();
5696605639SArthur Eubanks }
5796605639SArthur Eubanks
isZeroOrOneFP(Value * Op)586b8bd0f7SMatt Arsenault static bool isZeroOrOneFP(Value *Op) {
596b8bd0f7SMatt Arsenault const APFloat *C;
606b8bd0f7SMatt Arsenault return match(Op, m_APFloat(C)) &&
616b8bd0f7SMatt Arsenault ((C->isZero() && !C->isNegative()) || C->isExactlyValue(1.0));
626b8bd0f7SMatt Arsenault }
636b8bd0f7SMatt Arsenault
shouldReduceOperand(Use & Op)64be0b47d5SArthur Eubanks static bool shouldReduceOperand(Use &Op) {
65be0b47d5SArthur Eubanks Type *Ty = Op->getType();
66be0b47d5SArthur Eubanks if (Ty->isLabelTy() || Ty->isMetadataTy())
67be0b47d5SArthur Eubanks return false;
68be0b47d5SArthur Eubanks // TODO: be more precise about which GEP operands we can reduce (e.g. array
69be0b47d5SArthur Eubanks // indexes)
70be0b47d5SArthur Eubanks if (isa<GEPOperator>(Op.getUser()))
71be0b47d5SArthur Eubanks return false;
72be0b47d5SArthur Eubanks if (auto *CB = dyn_cast<CallBase>(Op.getUser())) {
73be0b47d5SArthur Eubanks if (&CB->getCalledOperandUse() == &Op)
74be0b47d5SArthur Eubanks return false;
75be0b47d5SArthur Eubanks }
76be0b47d5SArthur Eubanks return true;
77be0b47d5SArthur Eubanks }
78be0b47d5SArthur Eubanks
switchCaseExists(Use & Op,ConstantInt * CI)795b4f6d8bSJohn Regehr static bool switchCaseExists(Use &Op, ConstantInt *CI) {
805b4f6d8bSJohn Regehr SwitchInst *SI = dyn_cast<SwitchInst>(Op.getUser());
815b4f6d8bSJohn Regehr if (!SI)
825b4f6d8bSJohn Regehr return false;
835b4f6d8bSJohn Regehr return SI->findCaseValue(CI) != SI->case_default();
845b4f6d8bSJohn Regehr }
855b4f6d8bSJohn Regehr
reduceOperandsOneDeltaPass(TestRunner & Test)8696605639SArthur Eubanks void llvm::reduceOperandsOneDeltaPass(TestRunner &Test) {
8796605639SArthur Eubanks auto ReduceValue = [](Use &Op) -> Value * {
88be0b47d5SArthur Eubanks if (!shouldReduceOperand(Op))
8996605639SArthur Eubanks return nullptr;
906b8bd0f7SMatt Arsenault
916b8bd0f7SMatt Arsenault Type *Ty = Op->getType();
926b8bd0f7SMatt Arsenault if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
935b4f6d8bSJohn Regehr // Don't duplicate an existing switch case.
945b4f6d8bSJohn Regehr if (switchCaseExists(Op, ConstantInt::get(IntTy, 1)))
955b4f6d8bSJohn Regehr return nullptr;
9696605639SArthur Eubanks // Don't replace existing ones and zeroes.
976b8bd0f7SMatt Arsenault return (isOne(Op) || isZero(Op)) ? nullptr : ConstantInt::get(IntTy, 1);
986b8bd0f7SMatt Arsenault }
996b8bd0f7SMatt Arsenault
1006b8bd0f7SMatt Arsenault if (Ty->isFloatingPointTy())
1016b8bd0f7SMatt Arsenault return isZeroOrOneFP(Op) ? nullptr : ConstantFP::get(Ty, 1.0);
1026b8bd0f7SMatt Arsenault
1036b8bd0f7SMatt Arsenault if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
104bb3f99cdSFraser Cormack if (isOne(Op) || isZero(Op) || isZeroOrOneFP(Op))
1056b8bd0f7SMatt Arsenault return nullptr;
1066b8bd0f7SMatt Arsenault
107a946eb16SMatthias Braun Type *ElementType = VT->getElementType();
108a946eb16SMatthias Braun Constant *C;
109a946eb16SMatthias Braun if (ElementType->isFloatingPointTy()) {
110a946eb16SMatthias Braun C = ConstantFP::get(ElementType, 1.0);
111a946eb16SMatthias Braun } else if (IntegerType *IntTy = dyn_cast<IntegerType>(ElementType)) {
112a946eb16SMatthias Braun C = ConstantInt::get(IntTy, 1);
113a946eb16SMatthias Braun } else {
114a946eb16SMatthias Braun return nullptr;
115a946eb16SMatthias Braun }
116a946eb16SMatthias Braun return ConstantVector::getSplat(VT->getElementCount(), C);
1176b8bd0f7SMatt Arsenault }
1186b8bd0f7SMatt Arsenault
1196b8bd0f7SMatt Arsenault return nullptr;
12096605639SArthur Eubanks };
1212592ccdeSArthur Eubanks runDeltaPass(
1222592ccdeSArthur Eubanks Test,
123*23cc36e4SMatt Arsenault [ReduceValue](Oracle &O, ReducerWorkItem &WorkItem) {
124*23cc36e4SMatt Arsenault extractOperandsFromModule(O, WorkItem, ReduceValue);
1252592ccdeSArthur Eubanks },
1262592ccdeSArthur Eubanks "Reducing Operands to one");
12796605639SArthur Eubanks }
12896605639SArthur Eubanks
reduceOperandsZeroDeltaPass(TestRunner & Test)12996605639SArthur Eubanks void llvm::reduceOperandsZeroDeltaPass(TestRunner &Test) {
13096605639SArthur Eubanks auto ReduceValue = [](Use &Op) -> Value * {
131be0b47d5SArthur Eubanks if (!shouldReduceOperand(Op))
13296605639SArthur Eubanks return nullptr;
1335b4f6d8bSJohn Regehr // Don't duplicate an existing switch case.
1345b4f6d8bSJohn Regehr if (auto *IntTy = dyn_cast<IntegerType>(Op->getType()))
1355b4f6d8bSJohn Regehr if (switchCaseExists(Op, ConstantInt::get(IntTy, 0)))
1365b4f6d8bSJohn Regehr return nullptr;
13796605639SArthur Eubanks // Don't replace existing zeroes.
13896605639SArthur Eubanks return isZero(Op) ? nullptr : Constant::getNullValue(Op->getType());
13996605639SArthur Eubanks };
1402592ccdeSArthur Eubanks runDeltaPass(
1412592ccdeSArthur Eubanks Test,
142*23cc36e4SMatt Arsenault [ReduceValue](Oracle &O, ReducerWorkItem &Program) {
14396605639SArthur Eubanks extractOperandsFromModule(O, Program, ReduceValue);
1442592ccdeSArthur Eubanks },
1452592ccdeSArthur Eubanks "Reducing Operands to zero");
146f18c0739SSamuel }
14726107559SMatt Arsenault
reduceOperandsNaNDeltaPass(TestRunner & Test)14826107559SMatt Arsenault void llvm::reduceOperandsNaNDeltaPass(TestRunner &Test) {
14926107559SMatt Arsenault auto ReduceValue = [](Use &Op) -> Value * {
15026107559SMatt Arsenault Type *Ty = Op->getType();
15126107559SMatt Arsenault if (!Ty->isFPOrFPVectorTy())
15226107559SMatt Arsenault return nullptr;
15326107559SMatt Arsenault
15426107559SMatt Arsenault // Prefer 0.0 or 1.0 over NaN.
15526107559SMatt Arsenault //
15626107559SMatt Arsenault // TODO: Preferring NaN may make more sense because FP operations are more
15726107559SMatt Arsenault // universally foldable.
15826107559SMatt Arsenault if (match(Op.get(), m_NaN()) || isZeroOrOneFP(Op.get()))
15926107559SMatt Arsenault return nullptr;
16026107559SMatt Arsenault
16126107559SMatt Arsenault if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
16226107559SMatt Arsenault return ConstantVector::getSplat(VT->getElementCount(),
16326107559SMatt Arsenault ConstantFP::getQNaN(VT->getElementType()));
16426107559SMatt Arsenault }
16526107559SMatt Arsenault
16626107559SMatt Arsenault return ConstantFP::getQNaN(Ty);
16726107559SMatt Arsenault };
1682592ccdeSArthur Eubanks runDeltaPass(
1692592ccdeSArthur Eubanks Test,
170*23cc36e4SMatt Arsenault [ReduceValue](Oracle &O, ReducerWorkItem &Program) {
17126107559SMatt Arsenault extractOperandsFromModule(O, Program, ReduceValue);
1722592ccdeSArthur Eubanks },
1732592ccdeSArthur Eubanks "Reducing Operands to NaN");
17426107559SMatt Arsenault }
175