xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (revision 5942a99f8b7dd361c35eb1c9c32b2475dce2c0b2)
1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/SandboxIR/Function.h"
12 #include "llvm/SandboxIR/Instruction.h"
13 #include "llvm/SandboxIR/Module.h"
14 #include "llvm/SandboxIR/Utils.h"
15 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
17 
18 namespace llvm::sandboxir {
19 
/// Constructs the pass under the name "bottom-up-vec" and initializes the
/// `RPM` member with the name "rpm", the user-provided \p Pipeline string and
/// SandboxVectorizerPassBuilder::createRegionPass.
/// NOTE(review): `RPM` is presumably the region pass manager mentioned in the
/// TODO in runOnFunction(); it is not used anywhere else in this file yet.
BottomUpVec::BottomUpVec(StringRef Pipeline)
    : FunctionPass("bottom-up-vec"),
      RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
23 
24 // TODO: This is a temporary function that returns some seeds.
25 //       Replace this with SeedCollector's function when it lands.
26 static llvm::SmallVector<Value *, 4> collectSeeds(BasicBlock &BB) {
27   llvm::SmallVector<Value *, 4> Seeds;
28   for (auto &I : BB)
29     if (auto *SI = llvm::dyn_cast<StoreInst>(&I))
30       Seeds.push_back(SI);
31   return Seeds;
32 }
33 
34 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
35                                           unsigned OpIdx) {
36   SmallVector<Value *, 4> Operands;
37   for (Value *BndlV : Bndl) {
38     auto *BndlI = cast<Instruction>(BndlV);
39     Operands.push_back(BndlI->getOperand(OpIdx));
40   }
41   return Operands;
42 }
43 
44 static BasicBlock::iterator
45 getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
46   // TODO: Use the VecUtils function for getting the bottom instr once it lands.
47   auto *BotI = cast<Instruction>(
48       *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
49         return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
50       }));
51   // If Bndl contains Arguments or Constants, use the beginning of the BB.
52   return std::next(BotI->getIterator());
53 }
54 
55 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
56                                       ArrayRef<Value *> Operands) {
57   Change = true;
58   assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
59          "Expect Instructions!");
60   auto &Ctx = Bndl[0]->getContext();
61 
62   Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
63   auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
64 
65   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
66 
67   auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
68   switch (Opcode) {
69   case Instruction::Opcode::ZExt:
70   case Instruction::Opcode::SExt:
71   case Instruction::Opcode::FPToUI:
72   case Instruction::Opcode::FPToSI:
73   case Instruction::Opcode::FPExt:
74   case Instruction::Opcode::PtrToInt:
75   case Instruction::Opcode::IntToPtr:
76   case Instruction::Opcode::SIToFP:
77   case Instruction::Opcode::UIToFP:
78   case Instruction::Opcode::Trunc:
79   case Instruction::Opcode::FPTrunc:
80   case Instruction::Opcode::BitCast: {
81     assert(Operands.size() == 1u && "Casts are unary!");
82     return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
83   }
84   case Instruction::Opcode::FCmp:
85   case Instruction::Opcode::ICmp: {
86     auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
87     assert(all_of(drop_begin(Bndl),
88                   [Pred](auto *SBV) {
89                     return cast<CmpInst>(SBV)->getPredicate() == Pred;
90                   }) &&
91            "Expected same predicate across bundle.");
92     return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
93                            "VCmp");
94   }
95   case Instruction::Opcode::Select: {
96     return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
97                               Ctx, "Vec");
98   }
99   case Instruction::Opcode::FNeg: {
100     auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
101     auto OpC = UOp0->getOpcode();
102     return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
103                                                 Ctx, "Vec");
104   }
105   case Instruction::Opcode::Add:
106   case Instruction::Opcode::FAdd:
107   case Instruction::Opcode::Sub:
108   case Instruction::Opcode::FSub:
109   case Instruction::Opcode::Mul:
110   case Instruction::Opcode::FMul:
111   case Instruction::Opcode::UDiv:
112   case Instruction::Opcode::SDiv:
113   case Instruction::Opcode::FDiv:
114   case Instruction::Opcode::URem:
115   case Instruction::Opcode::SRem:
116   case Instruction::Opcode::FRem:
117   case Instruction::Opcode::Shl:
118   case Instruction::Opcode::LShr:
119   case Instruction::Opcode::AShr:
120   case Instruction::Opcode::And:
121   case Instruction::Opcode::Or:
122   case Instruction::Opcode::Xor: {
123     auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
124     auto *LHS = Operands[0];
125     auto *RHS = Operands[1];
126     return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
127                                                  BinOp0, WhereIt, Ctx, "Vec");
128   }
129   case Instruction::Opcode::Load: {
130     auto *Ld0 = cast<LoadInst>(Bndl[0]);
131     Value *Ptr = Ld0->getPointerOperand();
132     return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
133   }
134   case Instruction::Opcode::Store: {
135     auto Align = cast<StoreInst>(Bndl[0])->getAlign();
136     Value *Val = Operands[0];
137     Value *Ptr = Operands[1];
138     return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
139   }
140   case Instruction::Opcode::Br:
141   case Instruction::Opcode::Ret:
142   case Instruction::Opcode::PHI:
143   case Instruction::Opcode::AddrSpaceCast:
144   case Instruction::Opcode::Call:
145   case Instruction::Opcode::GetElementPtr:
146     llvm_unreachable("Unimplemented");
147     break;
148   default:
149     llvm_unreachable("Unimplemented");
150     break;
151   }
152   llvm_unreachable("Missing switch case!");
153   // TODO: Propagate debug info.
154 }
155 
156 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
157   Value *NewVec = nullptr;
158   const auto &LegalityRes = Legality->canVectorize(Bndl);
159   switch (LegalityRes.getSubclassID()) {
160   case LegalityResultID::Widen: {
161     auto *I = cast<Instruction>(Bndl[0]);
162     SmallVector<Value *, 2> VecOperands;
163     switch (I->getOpcode()) {
164     case Instruction::Opcode::Load:
165       // Don't recurse towards the pointer operand.
166       VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
167       break;
168     case Instruction::Opcode::Store: {
169       // Don't recurse towards the pointer operand.
170       auto *VecOp = vectorizeRec(getOperand(Bndl, 0));
171       VecOperands.push_back(VecOp);
172       VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
173       break;
174     }
175     default:
176       // Visit all operands.
177       for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
178         auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx));
179         VecOperands.push_back(VecOp);
180       }
181       break;
182     }
183     NewVec = createVectorInstr(Bndl, VecOperands);
184 
185     // TODO: Collect potentially dead instructions.
186     break;
187   }
188   case LegalityResultID::Pack: {
189     // TODO: Unimplemented
190     llvm_unreachable("Unimplemented");
191   }
192   }
193   return NewVec;
194 }
195 
/// Entry point for vectorizing the seed bundle \p Bndl. Runs vectorizeRec() on
/// it and discards the returned value.
/// \Returns the current value of the `Change` member. Note that `Change` is
/// set by createVectorInstr() and is only reset in runOnFunction(), so the
/// result also reflects changes made by earlier calls within the same run.
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  vectorizeRec(Bndl);
  return Change;
}
200 
201 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
202   Legality = std::make_unique<LegalityAnalysis>(
203       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
204       F.getContext());
205   Change = false;
206   // TODO: Start from innermost BBs first
207   for (auto &BB : F) {
208     // TODO: Replace with proper SeedCollector function.
209     auto Seeds = collectSeeds(BB);
210     // TODO: Slice Seeds into smaller chunks.
211     // TODO: If vectorization succeeds, run the RegionPassManager on the
212     // resulting region.
213     if (Seeds.size() >= 2)
214       Change |= tryVectorize(Seeds);
215   }
216   return Change;
217 }
218 
219 } // namespace llvm::sandboxir
220