xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (revision 7dffc96a54f90569d6226dd5713c80fc8f30c76f)
1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
#include <algorithm>
17 
18 namespace llvm::sandboxir {
19 
20 BottomUpVec::BottomUpVec(StringRef Pipeline)
21     : FunctionPass("bottom-up-vec"),
22       RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
23 
24 // TODO: This is a temporary function that returns some seeds.
25 //       Replace this with SeedCollector's function when it lands.
26 static llvm::SmallVector<Value *, 4> collectSeeds(BasicBlock &BB) {
27   llvm::SmallVector<Value *, 4> Seeds;
28   for (auto &I : BB)
29     if (auto *SI = llvm::dyn_cast<StoreInst>(&I))
30       Seeds.push_back(SI);
31   return Seeds;
32 }
33 
34 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
35                                           unsigned OpIdx) {
36   SmallVector<Value *, 4> Operands;
37   for (Value *BndlV : Bndl) {
38     auto *BndlI = cast<Instruction>(BndlV);
39     Operands.push_back(BndlI->getOperand(OpIdx));
40   }
41   return Operands;
42 }
43 
44 static BasicBlock::iterator
45 getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
46   // TODO: Use the VecUtils function for getting the bottom instr once it lands.
47   auto *BotI = cast<Instruction>(
48       *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
49         return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
50       }));
51   // If Bndl contains Arguments or Constants, use the beginning of the BB.
52   return std::next(BotI->getIterator());
53 }
54 
55 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
56                                       ArrayRef<Value *> Operands) {
57   Change = true;
58   assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
59          "Expect Instructions!");
60   auto &Ctx = Bndl[0]->getContext();
61 
62   Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
63   auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
64 
65   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
66 
67   auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
68   switch (Opcode) {
69   case Instruction::Opcode::ZExt:
70   case Instruction::Opcode::SExt:
71   case Instruction::Opcode::FPToUI:
72   case Instruction::Opcode::FPToSI:
73   case Instruction::Opcode::FPExt:
74   case Instruction::Opcode::PtrToInt:
75   case Instruction::Opcode::IntToPtr:
76   case Instruction::Opcode::SIToFP:
77   case Instruction::Opcode::UIToFP:
78   case Instruction::Opcode::Trunc:
79   case Instruction::Opcode::FPTrunc:
80   case Instruction::Opcode::BitCast: {
81     assert(Operands.size() == 1u && "Casts are unary!");
82     return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
83   }
84   case Instruction::Opcode::FCmp:
85   case Instruction::Opcode::ICmp: {
86     auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
87     assert(all_of(drop_begin(Bndl),
88                   [Pred](auto *SBV) {
89                     return cast<CmpInst>(SBV)->getPredicate() == Pred;
90                   }) &&
91            "Expected same predicate across bundle.");
92     return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
93                            "VCmp");
94   }
95   case Instruction::Opcode::Select: {
96     return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
97                               Ctx, "Vec");
98   }
99   case Instruction::Opcode::FNeg: {
100     auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
101     auto OpC = UOp0->getOpcode();
102     return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
103                                                 Ctx, "Vec");
104   }
105   case Instruction::Opcode::Add:
106   case Instruction::Opcode::FAdd:
107   case Instruction::Opcode::Sub:
108   case Instruction::Opcode::FSub:
109   case Instruction::Opcode::Mul:
110   case Instruction::Opcode::FMul:
111   case Instruction::Opcode::UDiv:
112   case Instruction::Opcode::SDiv:
113   case Instruction::Opcode::FDiv:
114   case Instruction::Opcode::URem:
115   case Instruction::Opcode::SRem:
116   case Instruction::Opcode::FRem:
117   case Instruction::Opcode::Shl:
118   case Instruction::Opcode::LShr:
119   case Instruction::Opcode::AShr:
120   case Instruction::Opcode::And:
121   case Instruction::Opcode::Or:
122   case Instruction::Opcode::Xor: {
123     auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
124     auto *LHS = Operands[0];
125     auto *RHS = Operands[1];
126     return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
127                                                  BinOp0, WhereIt, Ctx, "Vec");
128   }
129   case Instruction::Opcode::Load: {
130     auto *Ld0 = cast<LoadInst>(Bndl[0]);
131     Value *Ptr = Ld0->getPointerOperand();
132     return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
133   }
134   case Instruction::Opcode::Store: {
135     auto Align = cast<StoreInst>(Bndl[0])->getAlign();
136     Value *Val = Operands[0];
137     Value *Ptr = Operands[1];
138     return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
139   }
140   case Instruction::Opcode::Br:
141   case Instruction::Opcode::Ret:
142   case Instruction::Opcode::PHI:
143   case Instruction::Opcode::AddrSpaceCast:
144   case Instruction::Opcode::Call:
145   case Instruction::Opcode::GetElementPtr:
146     llvm_unreachable("Unimplemented");
147     break;
148   default:
149     llvm_unreachable("Unimplemented");
150     break;
151   }
152   llvm_unreachable("Missing switch case!");
153   // TODO: Propagate debug info.
154 }
155 
156 void BottomUpVec::tryEraseDeadInstrs() {
157   // Visiting the dead instructions bottom-to-top.
158   sort(DeadInstrCandidates,
159        [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
160   for (Instruction *I : reverse(DeadInstrCandidates)) {
161     if (I->hasNUses(0))
162       I->eraseFromParent();
163   }
164   DeadInstrCandidates.clear();
165 }
166 
167 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
168   Value *NewVec = nullptr;
169   const auto &LegalityRes = Legality->canVectorize(Bndl);
170   switch (LegalityRes.getSubclassID()) {
171   case LegalityResultID::Widen: {
172     auto *I = cast<Instruction>(Bndl[0]);
173     SmallVector<Value *, 2> VecOperands;
174     switch (I->getOpcode()) {
175     case Instruction::Opcode::Load:
176       // Don't recurse towards the pointer operand.
177       VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
178       break;
179     case Instruction::Opcode::Store: {
180       // Don't recurse towards the pointer operand.
181       auto *VecOp = vectorizeRec(getOperand(Bndl, 0));
182       VecOperands.push_back(VecOp);
183       VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
184       break;
185     }
186     default:
187       // Visit all operands.
188       for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
189         auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx));
190         VecOperands.push_back(VecOp);
191       }
192       break;
193     }
194     NewVec = createVectorInstr(Bndl, VecOperands);
195 
196     // Collect the original scalar instructions as they may be dead.
197     if (NewVec != nullptr) {
198       for (Value *V : Bndl)
199         DeadInstrCandidates.push_back(cast<Instruction>(V));
200     }
201     break;
202   }
203   case LegalityResultID::Pack: {
204     // TODO: Unimplemented
205     llvm_unreachable("Unimplemented");
206   }
207   }
208   return NewVec;
209 }
210 
211 bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
212   DeadInstrCandidates.clear();
213   vectorizeRec(Bndl);
214   tryEraseDeadInstrs();
215   return Change;
216 }
217 
218 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
219   Legality = std::make_unique<LegalityAnalysis>(
220       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
221       F.getContext());
222   Change = false;
223   // TODO: Start from innermost BBs first
224   for (auto &BB : F) {
225     // TODO: Replace with proper SeedCollector function.
226     auto Seeds = collectSeeds(BB);
227     // TODO: Slice Seeds into smaller chunks.
228     // TODO: If vectorization succeeds, run the RegionPassManager on the
229     // resulting region.
230     if (Seeds.size() >= 2)
231       Change |= tryVectorize(Seeds);
232   }
233   return Change;
234 }
235 
236 } // namespace llvm::sandboxir
237