xref: /llvm-project/llvm/lib/CodeGen/ExpandReductions.cpp (revision 2946cd701067404b99c39fb29dc9c74bd7193eb3)
1 //===--- ExpandReductions.cpp - Expand experimental reduction intrinsics --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements IR expansion for reduction intrinsics, allowing targets
10 // to enable the experimental intrinsics until just before codegen.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/ExpandReductions.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/CodeGen/Passes.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/InstIterator.h"
20 #include "llvm/IR/IntrinsicInst.h"
21 #include "llvm/IR/Intrinsics.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/Pass.h"
24 #include "llvm/Transforms/Utils/LoopUtils.h"
25 
26 using namespace llvm;
27 
28 namespace {
29 
30 unsigned getOpcode(Intrinsic::ID ID) {
31   switch (ID) {
32   case Intrinsic::experimental_vector_reduce_fadd:
33     return Instruction::FAdd;
34   case Intrinsic::experimental_vector_reduce_fmul:
35     return Instruction::FMul;
36   case Intrinsic::experimental_vector_reduce_add:
37     return Instruction::Add;
38   case Intrinsic::experimental_vector_reduce_mul:
39     return Instruction::Mul;
40   case Intrinsic::experimental_vector_reduce_and:
41     return Instruction::And;
42   case Intrinsic::experimental_vector_reduce_or:
43     return Instruction::Or;
44   case Intrinsic::experimental_vector_reduce_xor:
45     return Instruction::Xor;
46   case Intrinsic::experimental_vector_reduce_smax:
47   case Intrinsic::experimental_vector_reduce_smin:
48   case Intrinsic::experimental_vector_reduce_umax:
49   case Intrinsic::experimental_vector_reduce_umin:
50     return Instruction::ICmp;
51   case Intrinsic::experimental_vector_reduce_fmax:
52   case Intrinsic::experimental_vector_reduce_fmin:
53     return Instruction::FCmp;
54   default:
55     llvm_unreachable("Unexpected ID");
56   }
57 }
58 
59 RecurrenceDescriptor::MinMaxRecurrenceKind getMRK(Intrinsic::ID ID) {
60   switch (ID) {
61   case Intrinsic::experimental_vector_reduce_smax:
62     return RecurrenceDescriptor::MRK_SIntMax;
63   case Intrinsic::experimental_vector_reduce_smin:
64     return RecurrenceDescriptor::MRK_SIntMin;
65   case Intrinsic::experimental_vector_reduce_umax:
66     return RecurrenceDescriptor::MRK_UIntMax;
67   case Intrinsic::experimental_vector_reduce_umin:
68     return RecurrenceDescriptor::MRK_UIntMin;
69   case Intrinsic::experimental_vector_reduce_fmax:
70     return RecurrenceDescriptor::MRK_FloatMax;
71   case Intrinsic::experimental_vector_reduce_fmin:
72     return RecurrenceDescriptor::MRK_FloatMin;
73   default:
74     return RecurrenceDescriptor::MRK_Invalid;
75   }
76 }
77 
78 bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
79   bool Changed = false;
80   SmallVector<IntrinsicInst *, 4> Worklist;
81   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
82     if (auto II = dyn_cast<IntrinsicInst>(&*I))
83       Worklist.push_back(II);
84 
85   for (auto *II : Worklist) {
86     IRBuilder<> Builder(II);
87     bool IsOrdered = false;
88     Value *Acc = nullptr;
89     Value *Vec = nullptr;
90     auto ID = II->getIntrinsicID();
91     auto MRK = RecurrenceDescriptor::MRK_Invalid;
92     switch (ID) {
93     case Intrinsic::experimental_vector_reduce_fadd:
94     case Intrinsic::experimental_vector_reduce_fmul:
95       // FMFs must be attached to the call, otherwise it's an ordered reduction
96       // and it can't be handled by generating a shuffle sequence.
97       if (!II->getFastMathFlags().isFast())
98         IsOrdered = true;
99       Acc = II->getArgOperand(0);
100       Vec = II->getArgOperand(1);
101       break;
102     case Intrinsic::experimental_vector_reduce_add:
103     case Intrinsic::experimental_vector_reduce_mul:
104     case Intrinsic::experimental_vector_reduce_and:
105     case Intrinsic::experimental_vector_reduce_or:
106     case Intrinsic::experimental_vector_reduce_xor:
107     case Intrinsic::experimental_vector_reduce_smax:
108     case Intrinsic::experimental_vector_reduce_smin:
109     case Intrinsic::experimental_vector_reduce_umax:
110     case Intrinsic::experimental_vector_reduce_umin:
111     case Intrinsic::experimental_vector_reduce_fmax:
112     case Intrinsic::experimental_vector_reduce_fmin:
113       Vec = II->getArgOperand(0);
114       MRK = getMRK(ID);
115       break;
116     default:
117       continue;
118     }
119     if (!TTI->shouldExpandReduction(II))
120       continue;
121     Value *Rdx =
122         IsOrdered ? getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), MRK)
123                   : getShuffleReduction(Builder, Vec, getOpcode(ID), MRK);
124     II->replaceAllUsesWith(Rdx);
125     II->eraseFromParent();
126     Changed = true;
127   }
128   return Changed;
129 }
130 
131 class ExpandReductions : public FunctionPass {
132 public:
133   static char ID;
134   ExpandReductions() : FunctionPass(ID) {
135     initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
136   }
137 
138   bool runOnFunction(Function &F) override {
139     const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
140     return expandReductions(F, TTI);
141   }
142 
143   void getAnalysisUsage(AnalysisUsage &AU) const override {
144     AU.addRequired<TargetTransformInfoWrapperPass>();
145     AU.setPreservesCFG();
146   }
147 };
148 }
149 
150 char ExpandReductions::ID;
151 INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
152                       "Expand reduction intrinsics", false, false)
153 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
154 INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
155                     "Expand reduction intrinsics", false, false)
156 
157 FunctionPass *llvm::createExpandReductionsPass() {
158   return new ExpandReductions();
159 }
160 
161 PreservedAnalyses ExpandReductionsPass::run(Function &F,
162                                             FunctionAnalysisManager &AM) {
163   const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
164   if (!expandReductions(F, &TTI))
165     return PreservedAnalyses::all();
166   PreservedAnalyses PA;
167   PA.preserveSet<CFGAnalyses>();
168   return PA;
169 }
170