xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (revision fd087135efe1b62b506c3caef3fef83242a8e504)
1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/TargetTransformInfo.h"
12 #include "llvm/SandboxIR/Function.h"
13 #include "llvm/SandboxIR/Instruction.h"
14 #include "llvm/SandboxIR/Module.h"
15 #include "llvm/SandboxIR/Utils.h"
16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
17 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
18 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
19 
20 namespace llvm {
21 
22 static cl::opt<unsigned>
23     OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
24                        cl::desc("Override the vector register size in bits, "
25                                 "which is otherwise found by querying TTI."));
26 static cl::opt<bool>
27     AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
28                  cl::desc("Allow non-power-of-2 vectorization."));
29 
30 namespace sandboxir {
31 
32 BottomUpVec::BottomUpVec(StringRef Pipeline)
33     : FunctionPass("bottom-up-vec"),
34       RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
35 
36 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
37                                           unsigned OpIdx) {
38   SmallVector<Value *, 4> Operands;
39   for (Value *BndlV : Bndl) {
40     auto *BndlI = cast<Instruction>(BndlV);
41     Operands.push_back(BndlI->getOperand(OpIdx));
42   }
43   return Operands;
44 }
45 
46 static BasicBlock::iterator
47 getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
48   // TODO: Use the VecUtils function for getting the bottom instr once it lands.
49   auto *BotI = cast<Instruction>(
50       *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
51         return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
52       }));
53   // If Bndl contains Arguments or Constants, use the beginning of the BB.
54   return std::next(BotI->getIterator());
55 }
56 
57 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
58                                       ArrayRef<Value *> Operands) {
59   auto CreateVectorInstr = [](ArrayRef<Value *> Bndl,
60                               ArrayRef<Value *> Operands) -> Value * {
61     assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
62            "Expect Instructions!");
63     auto &Ctx = Bndl[0]->getContext();
64 
65     Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
66     auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
67 
68     BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
69 
70     auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
71     switch (Opcode) {
72     case Instruction::Opcode::ZExt:
73     case Instruction::Opcode::SExt:
74     case Instruction::Opcode::FPToUI:
75     case Instruction::Opcode::FPToSI:
76     case Instruction::Opcode::FPExt:
77     case Instruction::Opcode::PtrToInt:
78     case Instruction::Opcode::IntToPtr:
79     case Instruction::Opcode::SIToFP:
80     case Instruction::Opcode::UIToFP:
81     case Instruction::Opcode::Trunc:
82     case Instruction::Opcode::FPTrunc:
83     case Instruction::Opcode::BitCast: {
84       assert(Operands.size() == 1u && "Casts are unary!");
85       return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx,
86                               "VCast");
87     }
88     case Instruction::Opcode::FCmp:
89     case Instruction::Opcode::ICmp: {
90       auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
91       assert(all_of(drop_begin(Bndl),
92                     [Pred](auto *SBV) {
93                       return cast<CmpInst>(SBV)->getPredicate() == Pred;
94                     }) &&
95              "Expected same predicate across bundle.");
96       return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
97                              "VCmp");
98     }
99     case Instruction::Opcode::Select: {
100       return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
101                                 Ctx, "Vec");
102     }
103     case Instruction::Opcode::FNeg: {
104       auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
105       auto OpC = UOp0->getOpcode();
106       return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0,
107                                                   WhereIt, Ctx, "Vec");
108     }
109     case Instruction::Opcode::Add:
110     case Instruction::Opcode::FAdd:
111     case Instruction::Opcode::Sub:
112     case Instruction::Opcode::FSub:
113     case Instruction::Opcode::Mul:
114     case Instruction::Opcode::FMul:
115     case Instruction::Opcode::UDiv:
116     case Instruction::Opcode::SDiv:
117     case Instruction::Opcode::FDiv:
118     case Instruction::Opcode::URem:
119     case Instruction::Opcode::SRem:
120     case Instruction::Opcode::FRem:
121     case Instruction::Opcode::Shl:
122     case Instruction::Opcode::LShr:
123     case Instruction::Opcode::AShr:
124     case Instruction::Opcode::And:
125     case Instruction::Opcode::Or:
126     case Instruction::Opcode::Xor: {
127       auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
128       auto *LHS = Operands[0];
129       auto *RHS = Operands[1];
130       return BinaryOperator::createWithCopiedFlags(
131           BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec");
132     }
133     case Instruction::Opcode::Load: {
134       auto *Ld0 = cast<LoadInst>(Bndl[0]);
135       Value *Ptr = Ld0->getPointerOperand();
136       return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx,
137                               "VecL");
138     }
139     case Instruction::Opcode::Store: {
140       auto Align = cast<StoreInst>(Bndl[0])->getAlign();
141       Value *Val = Operands[0];
142       Value *Ptr = Operands[1];
143       return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
144     }
145     case Instruction::Opcode::Br:
146     case Instruction::Opcode::Ret:
147     case Instruction::Opcode::PHI:
148     case Instruction::Opcode::AddrSpaceCast:
149     case Instruction::Opcode::Call:
150     case Instruction::Opcode::GetElementPtr:
151       llvm_unreachable("Unimplemented");
152       break;
153     default:
154       llvm_unreachable("Unimplemented");
155       break;
156     }
157     llvm_unreachable("Missing switch case!");
158     // TODO: Propagate debug info.
159   };
160 
161   auto *VecI = CreateVectorInstr(Bndl, Operands);
162   if (VecI != nullptr) {
163     Change = true;
164     IMaps->registerVector(Bndl, VecI);
165   }
166   return VecI;
167 }
168 
169 void BottomUpVec::tryEraseDeadInstrs() {
170   // Visiting the dead instructions bottom-to-top.
171   SmallVector<Instruction *> SortedDeadInstrCandidates(
172       DeadInstrCandidates.begin(), DeadInstrCandidates.end());
173   sort(SortedDeadInstrCandidates,
174        [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
175   for (Instruction *I : reverse(SortedDeadInstrCandidates)) {
176     if (I->hasNUses(0))
177       I->eraseFromParent();
178   }
179   DeadInstrCandidates.clear();
180 }
181 
182 Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask) {
183   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp});
184   return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt,
185                                    VecOp->getContext(), "VShuf");
186 }
187 
188 Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
189   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);
190 
191   Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
192   unsigned Lanes = VecUtils::getNumLanes(ToPack);
193   Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);
194 
195   // Create a series of pack instructions.
196   Value *LastInsert = PoisonValue::get(VecTy);
197 
198   Context &Ctx = ToPack[0]->getContext();
199 
200   unsigned InsertIdx = 0;
201   for (Value *Elm : ToPack) {
202     // An element can be either scalar or vector. We need to generate different
203     // IR for each case.
204     if (Elm->getType()->isVectorTy()) {
205       unsigned NumElms =
206           cast<FixedVectorType>(Elm->getType())->getNumElements();
207       for (auto ExtrLane : seq<int>(0, NumElms)) {
208         // We generate extract-insert pairs, for each lane in `Elm`.
209         Constant *ExtrLaneC =
210             ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
211         // This may return a Constant if Elm is a Constant.
212         auto *ExtrI =
213             ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
214         if (!isa<Constant>(ExtrI))
215           WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
216         Constant *InsertLaneC =
217             ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
218         // This may also return a Constant if ExtrI is a Constant.
219         auto *InsertI = InsertElementInst::create(
220             LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
221         if (!isa<Constant>(InsertI)) {
222           LastInsert = InsertI;
223           WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
224         }
225       }
226     } else {
227       Constant *InsertLaneC =
228           ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
229       // This may be folded into a Constant if LastInsert is a Constant. In
230       // that case we only collect the last constant.
231       LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
232                                              WhereIt, Ctx, "Pack");
233       if (auto *NewI = dyn_cast<Instruction>(LastInsert))
234         WhereIt = std::next(NewI->getIterator());
235     }
236   }
237   return LastInsert;
238 }
239 
240 void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
241   for (Value *V : Bndl)
242     DeadInstrCandidates.insert(cast<Instruction>(V));
243   // Also collect the GEPs of vectorized loads and stores.
244   auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
245   switch (Opcode) {
246   case Instruction::Opcode::Load: {
247     for (Value *V : drop_begin(Bndl))
248       if (auto *Ptr =
249               dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand()))
250         DeadInstrCandidates.insert(Ptr);
251     break;
252   }
253   case Instruction::Opcode::Store: {
254     for (Value *V : drop_begin(Bndl))
255       if (auto *Ptr =
256               dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand()))
257         DeadInstrCandidates.insert(Ptr);
258     break;
259   }
260   default:
261     break;
262   }
263 }
264 
265 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
266   Value *NewVec = nullptr;
267   const auto &LegalityRes = Legality->canVectorize(Bndl);
268   switch (LegalityRes.getSubclassID()) {
269   case LegalityResultID::Widen: {
270     auto *I = cast<Instruction>(Bndl[0]);
271     SmallVector<Value *, 2> VecOperands;
272     switch (I->getOpcode()) {
273     case Instruction::Opcode::Load:
274       // Don't recurse towards the pointer operand.
275       VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
276       break;
277     case Instruction::Opcode::Store: {
278       // Don't recurse towards the pointer operand.
279       auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
280       VecOperands.push_back(VecOp);
281       VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
282       break;
283     }
284     default:
285       // Visit all operands.
286       for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
287         auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
288         VecOperands.push_back(VecOp);
289       }
290       break;
291     }
292     NewVec = createVectorInstr(Bndl, VecOperands);
293 
294     // Collect any potentially dead scalar instructions, including the original
295     // scalars and pointer operands of loads/stores.
296     if (NewVec != nullptr)
297       collectPotentiallyDeadInstrs(Bndl);
298     break;
299   }
300   case LegalityResultID::DiamondReuse: {
301     NewVec = cast<DiamondReuse>(LegalityRes).getVector();
302     break;
303   }
304   case LegalityResultID::DiamondReuseWithShuffle: {
305     auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
306     const ShuffleMask &Mask =
307         cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
308     NewVec = createShuffle(VecOp, Mask);
309     break;
310   }
311   case LegalityResultID::DiamondReuseMultiInput: {
312     const auto &Descr =
313         cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
314     Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
315 
316     // TODO: Try to get WhereIt without creating a vector.
317     SmallVector<Value *, 4> DescrInstrs;
318     for (const auto &ElmDescr : Descr.getDescrs()) {
319       if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
320         DescrInstrs.push_back(I);
321     }
322     auto WhereIt = getInsertPointAfterInstrs(DescrInstrs);
323 
324     Value *LastV = PoisonValue::get(ResTy);
325     for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
326       Value *VecOp = ElmDescr.getValue();
327       Context &Ctx = VecOp->getContext();
328       Value *ValueToInsert;
329       if (ElmDescr.needsExtract()) {
330         ConstantInt *IdxC =
331             ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx());
332         ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt,
333                                                    VecOp->getContext(), "VExt");
334       } else {
335         ValueToInsert = VecOp;
336       }
337       ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
338       Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
339                                              WhereIt, Ctx, "VIns");
340       LastV = Ins;
341     }
342     NewVec = LastV;
343     break;
344   }
345   case LegalityResultID::Pack: {
346     // If we can't vectorize the seeds then just return.
347     if (Depth == 0)
348       return nullptr;
349     NewVec = createPack(Bndl);
350     break;
351   }
352   }
353   return NewVec;
354 }
355 
356 bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
357   DeadInstrCandidates.clear();
358   Legality->clear();
359   vectorizeRec(Bndl, /*Depth=*/0);
360   tryEraseDeadInstrs();
361   return Change;
362 }
363 
364 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
365   IMaps = std::make_unique<InstrMaps>(F.getContext());
366   Legality = std::make_unique<LegalityAnalysis>(
367       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
368       F.getContext(), *IMaps);
369   Change = false;
370   const auto &DL = F.getParent()->getDataLayout();
371   unsigned VecRegBits =
372       OverrideVecRegBits != 0
373           ? OverrideVecRegBits
374           : A.getTTI()
375                 .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
376                 .getFixedValue();
377 
378   // TODO: Start from innermost BBs first
379   for (auto &BB : F) {
380     SeedCollector SC(&BB, A.getScalarEvolution());
381     for (SeedBundle &Seeds : SC.getStoreSeeds()) {
382       unsigned ElmBits =
383           Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
384                                 Seeds[Seeds.getFirstUnusedElementIdx()])),
385                             DL);
386 
387       auto DivideBy2 = [](unsigned Num) {
388         auto Floor = VecUtils::getFloorPowerOf2(Num);
389         if (Floor == Num)
390           return Floor / 2;
391         return Floor;
392       };
393       // Try to create the largest vector supported by the target. If it fails
394       // reduce the vector size by half.
395       for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
396                                          Seeds.getNumUnusedBits() / ElmBits);
397            SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
398         if (Seeds.allUsed())
399           break;
400         // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
401         // the slice. This could be quite expensive, so we enforce a limit.
402         for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
403                       OE = Seeds.size();
404              Offset + 1 < OE; Offset += 1) {
405           // Seeds are getting used as we vectorize, so skip them.
406           if (Seeds.isUsed(Offset))
407             continue;
408           if (Seeds.allUsed())
409             break;
410 
411           auto SeedSlice =
412               Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
413           if (SeedSlice.empty())
414             continue;
415 
416           assert(SeedSlice.size() >= 2 && "Should have been rejected!");
417 
418           // TODO: If vectorization succeeds, run the RegionPassManager on the
419           // resulting region.
420 
421           // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
422           SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
423                                              SeedSlice.end());
424           Change |= tryVectorize(SeedSliceVals);
425         }
426       }
427     }
428   }
429   return Change;
430 }
431 
432 } // namespace sandboxir
433 } // namespace llvm
434