1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h" 10 #include "llvm/ADT/SmallVector.h" 11 #include "llvm/SandboxIR/Function.h" 12 #include "llvm/SandboxIR/Instruction.h" 13 #include "llvm/SandboxIR/Module.h" 14 #include "llvm/SandboxIR/Utils.h" 15 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h" 16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" 17 18 namespace llvm::sandboxir { 19 20 BottomUpVec::BottomUpVec(StringRef Pipeline) 21 : FunctionPass("bottom-up-vec"), 22 RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {} 23 24 // TODO: This is a temporary function that returns some seeds. 25 // Replace this with SeedCollector's function when it lands. 26 static llvm::SmallVector<Value *, 4> collectSeeds(BasicBlock &BB) { 27 llvm::SmallVector<Value *, 4> Seeds; 28 for (auto &I : BB) 29 if (auto *SI = llvm::dyn_cast<StoreInst>(&I)) 30 Seeds.push_back(SI); 31 return Seeds; 32 } 33 34 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl, 35 unsigned OpIdx) { 36 SmallVector<Value *, 4> Operands; 37 for (Value *BndlV : Bndl) { 38 auto *BndlI = cast<Instruction>(BndlV); 39 Operands.push_back(BndlI->getOperand(OpIdx)); 40 } 41 return Operands; 42 } 43 44 static BasicBlock::iterator 45 getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) { 46 // TODO: Use the VecUtils function for getting the bottom instr once it lands. 47 auto *BotI = cast<Instruction>( 48 *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) { 49 return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2)); 50 })); 51 // If Bndl contains Arguments or Constants, use the beginning of the BB. 52 return std::next(BotI->getIterator()); 53 } 54 55 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl, 56 ArrayRef<Value *> Operands) { 57 Change = true; 58 assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) && 59 "Expect Instructions!"); 60 auto &Ctx = Bndl[0]->getContext(); 61 62 Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0])); 63 auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl)); 64 65 BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl); 66 67 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode(); 68 switch (Opcode) { 69 case Instruction::Opcode::ZExt: 70 case Instruction::Opcode::SExt: 71 case Instruction::Opcode::FPToUI: 72 case Instruction::Opcode::FPToSI: 73 case Instruction::Opcode::FPExt: 74 case Instruction::Opcode::PtrToInt: 75 case Instruction::Opcode::IntToPtr: 76 case Instruction::Opcode::SIToFP: 77 case Instruction::Opcode::UIToFP: 78 case Instruction::Opcode::Trunc: 79 case Instruction::Opcode::FPTrunc: 80 case Instruction::Opcode::BitCast: { 81 assert(Operands.size() == 1u && "Casts are unary!"); 82 return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast"); 83 } 84 case Instruction::Opcode::FCmp: 85 case Instruction::Opcode::ICmp: { 86 auto Pred = cast<CmpInst>(Bndl[0])->getPredicate(); 87 assert(all_of(drop_begin(Bndl), 88 [Pred](auto *SBV) { 89 return cast<CmpInst>(SBV)->getPredicate() == Pred; 90 }) && 91 "Expected same predicate across bundle."); 92 return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx, 93 "VCmp"); 94 } 95 case Instruction::Opcode::Select: { 96 return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt, 97 Ctx, "Vec"); 98 } 99 case Instruction::Opcode::FNeg: { 100 auto *UOp0 = cast<UnaryOperator>(Bndl[0]); 101 auto OpC = UOp0->getOpcode(); 102 return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt, 103 Ctx, "Vec"); 104 } 105 case Instruction::Opcode::Add: 106 case Instruction::Opcode::FAdd: 107 case Instruction::Opcode::Sub: 108 case Instruction::Opcode::FSub: 109 case Instruction::Opcode::Mul: 110 case Instruction::Opcode::FMul: 111 case Instruction::Opcode::UDiv: 112 case Instruction::Opcode::SDiv: 113 case Instruction::Opcode::FDiv: 114 case Instruction::Opcode::URem: 115 case Instruction::Opcode::SRem: 116 case Instruction::Opcode::FRem: 117 case Instruction::Opcode::Shl: 118 case Instruction::Opcode::LShr: 119 case Instruction::Opcode::AShr: 120 case Instruction::Opcode::And: 121 case Instruction::Opcode::Or: 122 case Instruction::Opcode::Xor: { 123 auto *BinOp0 = cast<BinaryOperator>(Bndl[0]); 124 auto *LHS = Operands[0]; 125 auto *RHS = Operands[1]; 126 return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS, 127 BinOp0, WhereIt, Ctx, "Vec"); 128 } 129 case Instruction::Opcode::Load: { 130 auto *Ld0 = cast<LoadInst>(Bndl[0]); 131 Value *Ptr = Ld0->getPointerOperand(); 132 return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL"); 133 } 134 case Instruction::Opcode::Store: { 135 auto Align = cast<StoreInst>(Bndl[0])->getAlign(); 136 Value *Val = Operands[0]; 137 Value *Ptr = Operands[1]; 138 return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx); 139 } 140 case Instruction::Opcode::Br: 141 case Instruction::Opcode::Ret: 142 case Instruction::Opcode::PHI: 143 case Instruction::Opcode::AddrSpaceCast: 144 case Instruction::Opcode::Call: 145 case Instruction::Opcode::GetElementPtr: 146 llvm_unreachable("Unimplemented"); 147 break; 148 default: 149 llvm_unreachable("Unimplemented"); 150 break; 151 } 152 llvm_unreachable("Missing switch case!"); 153 // TODO: Propagate debug info. 154 } 155 156 void BottomUpVec::tryEraseDeadInstrs() { 157 // Visiting the dead instructions bottom-to-top. 158 sort(DeadInstrCandidates, 159 [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); }); 160 for (Instruction *I : reverse(DeadInstrCandidates)) { 161 if (I->hasNUses(0)) 162 I->eraseFromParent(); 163 } 164 DeadInstrCandidates.clear(); 165 } 166 167 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) { 168 Value *NewVec = nullptr; 169 const auto &LegalityRes = Legality->canVectorize(Bndl); 170 switch (LegalityRes.getSubclassID()) { 171 case LegalityResultID::Widen: { 172 auto *I = cast<Instruction>(Bndl[0]); 173 SmallVector<Value *, 2> VecOperands; 174 switch (I->getOpcode()) { 175 case Instruction::Opcode::Load: 176 // Don't recurse towards the pointer operand. 177 VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand()); 178 break; 179 case Instruction::Opcode::Store: { 180 // Don't recurse towards the pointer operand. 181 auto *VecOp = vectorizeRec(getOperand(Bndl, 0)); 182 VecOperands.push_back(VecOp); 183 VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand()); 184 break; 185 } 186 default: 187 // Visit all operands. 188 for (auto OpIdx : seq<unsigned>(I->getNumOperands())) { 189 auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx)); 190 VecOperands.push_back(VecOp); 191 } 192 break; 193 } 194 NewVec = createVectorInstr(Bndl, VecOperands); 195 196 // Collect the original scalar instructions as they may be dead. 197 if (NewVec != nullptr) { 198 for (Value *V : Bndl) 199 DeadInstrCandidates.push_back(cast<Instruction>(V)); 200 } 201 break; 202 } 203 case LegalityResultID::Pack: { 204 // TODO: Unimplemented 205 llvm_unreachable("Unimplemented"); 206 } 207 } 208 return NewVec; 209 } 210 211 bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { 212 DeadInstrCandidates.clear(); 213 vectorizeRec(Bndl); 214 tryEraseDeadInstrs(); 215 return Change; 216 } 217 218 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) { 219 Legality = std::make_unique<LegalityAnalysis>( 220 A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(), 221 F.getContext()); 222 Change = false; 223 // TODO: Start from innermost BBs first 224 for (auto &BB : F) { 225 // TODO: Replace with proper SeedCollector function. 226 auto Seeds = collectSeeds(BB); 227 // TODO: Slice Seeds into smaller chunks. 228 // TODO: If vectorization succeeds, run the RegionPassManager on the 229 // resulting region. 230 if (Seeds.size() >= 2) 231 Change |= tryVectorize(Seeds); 232 } 233 return Change; 234 } 235 236 } // namespace llvm::sandboxir 237