1349cc55cSDimitry Andric //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===// 2349cc55cSDimitry Andric // 3349cc55cSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4349cc55cSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5349cc55cSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6349cc55cSDimitry Andric // 7349cc55cSDimitry Andric //===----------------------------------------------------------------------===// 8349cc55cSDimitry Andric // 9349cc55cSDimitry Andric // This pass custom lowers llvm.gather and llvm.scatter instructions to 1006c3fb27SDimitry Andric // RISC-V intrinsics. 11349cc55cSDimitry Andric // 12349cc55cSDimitry Andric //===----------------------------------------------------------------------===// 13349cc55cSDimitry Andric 14349cc55cSDimitry Andric #include "RISCV.h" 15349cc55cSDimitry Andric #include "RISCVTargetMachine.h" 1606c3fb27SDimitry Andric #include "llvm/Analysis/InstSimplifyFolder.h" 17349cc55cSDimitry Andric #include "llvm/Analysis/LoopInfo.h" 18349cc55cSDimitry Andric #include "llvm/Analysis/ValueTracking.h" 19349cc55cSDimitry Andric #include "llvm/Analysis/VectorUtils.h" 20349cc55cSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 21349cc55cSDimitry Andric #include "llvm/IR/GetElementPtrTypeIterator.h" 22349cc55cSDimitry Andric #include "llvm/IR/IRBuilder.h" 23349cc55cSDimitry Andric #include "llvm/IR/IntrinsicInst.h" 24349cc55cSDimitry Andric #include "llvm/IR/IntrinsicsRISCV.h" 25bdd1243dSDimitry Andric #include "llvm/IR/PatternMatch.h" 26349cc55cSDimitry Andric #include "llvm/Transforms/Utils/Local.h" 27bdd1243dSDimitry Andric #include <optional> 28349cc55cSDimitry Andric 29349cc55cSDimitry Andric using namespace llvm; 30bdd1243dSDimitry Andric using namespace PatternMatch; 31349cc55cSDimitry Andric 32349cc55cSDimitry Andric #define DEBUG_TYPE "riscv-gather-scatter-lowering" 33349cc55cSDimitry Andric 34349cc55cSDimitry Andric namespace { 35349cc55cSDimitry Andric 36349cc55cSDimitry Andric class RISCVGatherScatterLowering : public FunctionPass { 37349cc55cSDimitry Andric const RISCVSubtarget *ST = nullptr; 38349cc55cSDimitry Andric const RISCVTargetLowering *TLI = nullptr; 39349cc55cSDimitry Andric LoopInfo *LI = nullptr; 40349cc55cSDimitry Andric const DataLayout *DL = nullptr; 41349cc55cSDimitry Andric 42349cc55cSDimitry Andric SmallVector<WeakTrackingVH> MaybeDeadPHIs; 43349cc55cSDimitry Andric 4481ad6265SDimitry Andric // Cache of the BasePtr and Stride determined from this GEP. When a GEP is 4581ad6265SDimitry Andric // used by multiple gathers/scatters, this allow us to reuse the scalar 4681ad6265SDimitry Andric // instructions we created for the first gather/scatter for the others. 4781ad6265SDimitry Andric DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs; 4881ad6265SDimitry Andric 49349cc55cSDimitry Andric public: 50349cc55cSDimitry Andric static char ID; // Pass identification, replacement for typeid 51349cc55cSDimitry Andric 52349cc55cSDimitry Andric RISCVGatherScatterLowering() : FunctionPass(ID) {} 53349cc55cSDimitry Andric 54349cc55cSDimitry Andric bool runOnFunction(Function &F) override; 55349cc55cSDimitry Andric 56349cc55cSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 57349cc55cSDimitry Andric AU.setPreservesCFG(); 58349cc55cSDimitry Andric AU.addRequired<TargetPassConfig>(); 59349cc55cSDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 60349cc55cSDimitry Andric } 61349cc55cSDimitry Andric 62349cc55cSDimitry Andric StringRef getPassName() const override { 6306c3fb27SDimitry Andric return "RISC-V gather/scatter lowering"; 64349cc55cSDimitry Andric } 65349cc55cSDimitry Andric 66349cc55cSDimitry Andric private: 67349cc55cSDimitry Andric bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr, 68349cc55cSDimitry Andric Value *AlignOp); 69349cc55cSDimitry Andric 705f757f3fSDimitry Andric std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr, 7106c3fb27SDimitry Andric IRBuilderBase &Builder); 72349cc55cSDimitry Andric 73349cc55cSDimitry Andric bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride, 74349cc55cSDimitry Andric PHINode *&BasePtr, BinaryOperator *&Inc, 7506c3fb27SDimitry Andric IRBuilderBase &Builder); 76349cc55cSDimitry Andric }; 77349cc55cSDimitry Andric 78349cc55cSDimitry Andric } // end anonymous namespace 79349cc55cSDimitry Andric 80349cc55cSDimitry Andric char RISCVGatherScatterLowering::ID = 0; 81349cc55cSDimitry Andric 82349cc55cSDimitry Andric INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE, 8306c3fb27SDimitry Andric "RISC-V gather/scatter lowering pass", false, false) 84349cc55cSDimitry Andric 85349cc55cSDimitry Andric FunctionPass *llvm::createRISCVGatherScatterLoweringPass() { 86349cc55cSDimitry Andric return new RISCVGatherScatterLowering(); 87349cc55cSDimitry Andric } 88349cc55cSDimitry Andric 89349cc55cSDimitry Andric // TODO: Should we consider the mask when looking for a stride? 90349cc55cSDimitry Andric static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) { 9106c3fb27SDimitry Andric if (!isa<FixedVectorType>(StartC->getType())) 9206c3fb27SDimitry Andric return std::make_pair(nullptr, nullptr); 9306c3fb27SDimitry Andric 94349cc55cSDimitry Andric unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements(); 95349cc55cSDimitry Andric 96349cc55cSDimitry Andric // Check that the start value is a strided constant. 97349cc55cSDimitry Andric auto *StartVal = 98349cc55cSDimitry Andric dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0)); 99349cc55cSDimitry Andric if (!StartVal) 100349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 101349cc55cSDimitry Andric APInt StrideVal(StartVal->getValue().getBitWidth(), 0); 102349cc55cSDimitry Andric ConstantInt *Prev = StartVal; 103349cc55cSDimitry Andric for (unsigned i = 1; i != NumElts; ++i) { 104349cc55cSDimitry Andric auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i)); 105349cc55cSDimitry Andric if (!C) 106349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 107349cc55cSDimitry Andric 108349cc55cSDimitry Andric APInt LocalStride = C->getValue() - Prev->getValue(); 109349cc55cSDimitry Andric if (i == 1) 110349cc55cSDimitry Andric StrideVal = LocalStride; 111349cc55cSDimitry Andric else if (StrideVal != LocalStride) 112349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 113349cc55cSDimitry Andric 114349cc55cSDimitry Andric Prev = C; 115349cc55cSDimitry Andric } 116349cc55cSDimitry Andric 117349cc55cSDimitry Andric Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal); 118349cc55cSDimitry Andric 119349cc55cSDimitry Andric return std::make_pair(StartVal, Stride); 120349cc55cSDimitry Andric } 121349cc55cSDimitry Andric 12204eeddc0SDimitry Andric static std::pair<Value *, Value *> matchStridedStart(Value *Start, 12306c3fb27SDimitry Andric IRBuilderBase &Builder) { 12404eeddc0SDimitry Andric // Base case, start is a strided constant. 12504eeddc0SDimitry Andric auto *StartC = dyn_cast<Constant>(Start); 12604eeddc0SDimitry Andric if (StartC) 12704eeddc0SDimitry Andric return matchStridedConstant(StartC); 12804eeddc0SDimitry Andric 129bdd1243dSDimitry Andric // Base case, start is a stepvector 130bdd1243dSDimitry Andric if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) { 131bdd1243dSDimitry Andric auto *Ty = Start->getType()->getScalarType(); 132bdd1243dSDimitry Andric return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1)); 133bdd1243dSDimitry Andric } 134bdd1243dSDimitry Andric 13506c3fb27SDimitry Andric // Not a constant, maybe it's a strided constant with a splat added or 13606c3fb27SDimitry Andric // multipled. 13704eeddc0SDimitry Andric auto *BO = dyn_cast<BinaryOperator>(Start); 13806c3fb27SDimitry Andric if (!BO || (BO->getOpcode() != Instruction::Add && 1397a6dacacSDimitry Andric BO->getOpcode() != Instruction::Or && 14006c3fb27SDimitry Andric BO->getOpcode() != Instruction::Shl && 14106c3fb27SDimitry Andric BO->getOpcode() != Instruction::Mul)) 14204eeddc0SDimitry Andric return std::make_pair(nullptr, nullptr); 14304eeddc0SDimitry Andric 1447a6dacacSDimitry Andric if (BO->getOpcode() == Instruction::Or && 1457a6dacacSDimitry Andric !cast<PossiblyDisjointInst>(BO)->isDisjoint()) 1467a6dacacSDimitry Andric return std::make_pair(nullptr, nullptr); 1477a6dacacSDimitry Andric 14804eeddc0SDimitry Andric // Look for an operand that is splatted. 14906c3fb27SDimitry Andric unsigned OtherIndex = 0; 15006c3fb27SDimitry Andric Value *Splat = getSplatValue(BO->getOperand(1)); 15106c3fb27SDimitry Andric if (!Splat && Instruction::isCommutative(BO->getOpcode())) { 15206c3fb27SDimitry Andric Splat = getSplatValue(BO->getOperand(0)); 15306c3fb27SDimitry Andric OtherIndex = 1; 15404eeddc0SDimitry Andric } 15504eeddc0SDimitry Andric if (!Splat) 15604eeddc0SDimitry Andric return std::make_pair(nullptr, nullptr); 15704eeddc0SDimitry Andric 15804eeddc0SDimitry Andric Value *Stride; 15904eeddc0SDimitry Andric std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex), 16004eeddc0SDimitry Andric Builder); 16104eeddc0SDimitry Andric if (!Start) 16204eeddc0SDimitry Andric return std::make_pair(nullptr, nullptr); 16304eeddc0SDimitry Andric 16404eeddc0SDimitry Andric Builder.SetInsertPoint(BO); 16504eeddc0SDimitry Andric Builder.SetCurrentDebugLocation(DebugLoc()); 16606c3fb27SDimitry Andric // Add the splat value to the start or multiply the start and stride by the 16706c3fb27SDimitry Andric // splat. 16806c3fb27SDimitry Andric switch (BO->getOpcode()) { 16906c3fb27SDimitry Andric default: 17006c3fb27SDimitry Andric llvm_unreachable("Unexpected opcode"); 1717a6dacacSDimitry Andric case Instruction::Or: 1727a6dacacSDimitry Andric // TODO: We'd be better off creating disjoint or here, but we don't yet 1737a6dacacSDimitry Andric // have an IRBuilder API for that. 1747a6dacacSDimitry Andric [[fallthrough]]; 17506c3fb27SDimitry Andric case Instruction::Add: 17604eeddc0SDimitry Andric Start = Builder.CreateAdd(Start, Splat); 17706c3fb27SDimitry Andric break; 17806c3fb27SDimitry Andric case Instruction::Mul: 17906c3fb27SDimitry Andric Start = Builder.CreateMul(Start, Splat); 18006c3fb27SDimitry Andric Stride = Builder.CreateMul(Stride, Splat); 18106c3fb27SDimitry Andric break; 18206c3fb27SDimitry Andric case Instruction::Shl: 18306c3fb27SDimitry Andric Start = Builder.CreateShl(Start, Splat); 18406c3fb27SDimitry Andric Stride = Builder.CreateShl(Stride, Splat); 18506c3fb27SDimitry Andric break; 18606c3fb27SDimitry Andric } 18706c3fb27SDimitry Andric 18804eeddc0SDimitry Andric return std::make_pair(Start, Stride); 18904eeddc0SDimitry Andric } 19004eeddc0SDimitry Andric 191349cc55cSDimitry Andric // Recursively, walk about the use-def chain until we find a Phi with a strided 192349cc55cSDimitry Andric // start value. Build and update a scalar recurrence as we unwind the recursion. 193349cc55cSDimitry Andric // We also update the Stride as we unwind. Our goal is to move all of the 194349cc55cSDimitry Andric // arithmetic out of the loop. 195349cc55cSDimitry Andric bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L, 196349cc55cSDimitry Andric Value *&Stride, 197349cc55cSDimitry Andric PHINode *&BasePtr, 198349cc55cSDimitry Andric BinaryOperator *&Inc, 19906c3fb27SDimitry Andric IRBuilderBase &Builder) { 200349cc55cSDimitry Andric // Our base case is a Phi. 201349cc55cSDimitry Andric if (auto *Phi = dyn_cast<PHINode>(Index)) { 202349cc55cSDimitry Andric // A phi node we want to perform this function on should be from the 203349cc55cSDimitry Andric // loop header. 204349cc55cSDimitry Andric if (Phi->getParent() != L->getHeader()) 205349cc55cSDimitry Andric return false; 206349cc55cSDimitry Andric 207349cc55cSDimitry Andric Value *Step, *Start; 208349cc55cSDimitry Andric if (!matchSimpleRecurrence(Phi, Inc, Start, Step) || 209349cc55cSDimitry Andric Inc->getOpcode() != Instruction::Add) 210349cc55cSDimitry Andric return false; 211349cc55cSDimitry Andric assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); 212349cc55cSDimitry Andric unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1; 213349cc55cSDimitry Andric assert(Phi->getIncomingValue(IncrementingBlock) == Inc && 214349cc55cSDimitry Andric "Expected one operand of phi to be Inc"); 215349cc55cSDimitry Andric 216349cc55cSDimitry Andric // Only proceed if the step is loop invariant. 217349cc55cSDimitry Andric if (!L->isLoopInvariant(Step)) 218349cc55cSDimitry Andric return false; 219349cc55cSDimitry Andric 220349cc55cSDimitry Andric // Step should be a splat. 221349cc55cSDimitry Andric Step = getSplatValue(Step); 222349cc55cSDimitry Andric if (!Step) 223349cc55cSDimitry Andric return false; 224349cc55cSDimitry Andric 22504eeddc0SDimitry Andric std::tie(Start, Stride) = matchStridedStart(Start, Builder); 226349cc55cSDimitry Andric if (!Start) 227349cc55cSDimitry Andric return false; 228349cc55cSDimitry Andric assert(Stride != nullptr); 229349cc55cSDimitry Andric 230349cc55cSDimitry Andric // Build scalar phi and increment. 231349cc55cSDimitry Andric BasePtr = 232*0fca6ea1SDimitry Andric PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi->getIterator()); 233349cc55cSDimitry Andric Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar", 234*0fca6ea1SDimitry Andric Inc->getIterator()); 235349cc55cSDimitry Andric BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock)); 236349cc55cSDimitry Andric BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock)); 237349cc55cSDimitry Andric 238349cc55cSDimitry Andric // Note that this Phi might be eligible for removal. 239349cc55cSDimitry Andric MaybeDeadPHIs.push_back(Phi); 240349cc55cSDimitry Andric return true; 241349cc55cSDimitry Andric } 242349cc55cSDimitry Andric 243349cc55cSDimitry Andric // Otherwise look for binary operator. 244349cc55cSDimitry Andric auto *BO = dyn_cast<BinaryOperator>(Index); 245349cc55cSDimitry Andric if (!BO) 246349cc55cSDimitry Andric return false; 247349cc55cSDimitry Andric 24806c3fb27SDimitry Andric switch (BO->getOpcode()) { 24906c3fb27SDimitry Andric default: 250349cc55cSDimitry Andric return false; 25106c3fb27SDimitry Andric case Instruction::Or: 252349cc55cSDimitry Andric // We need to be able to treat Or as Add. 2537a6dacacSDimitry Andric if (!cast<PossiblyDisjointInst>(BO)->isDisjoint()) 254349cc55cSDimitry Andric return false; 25506c3fb27SDimitry Andric break; 25606c3fb27SDimitry Andric case Instruction::Add: 25706c3fb27SDimitry Andric break; 25806c3fb27SDimitry Andric case Instruction::Shl: 25906c3fb27SDimitry Andric break; 26006c3fb27SDimitry Andric case Instruction::Mul: 26106c3fb27SDimitry Andric break; 26206c3fb27SDimitry Andric } 263349cc55cSDimitry Andric 264349cc55cSDimitry Andric // We should have one operand in the loop and one splat. 265349cc55cSDimitry Andric Value *OtherOp; 266349cc55cSDimitry Andric if (isa<Instruction>(BO->getOperand(0)) && 267349cc55cSDimitry Andric L->contains(cast<Instruction>(BO->getOperand(0)))) { 268349cc55cSDimitry Andric Index = cast<Instruction>(BO->getOperand(0)); 269349cc55cSDimitry Andric OtherOp = BO->getOperand(1); 270349cc55cSDimitry Andric } else if (isa<Instruction>(BO->getOperand(1)) && 27106c3fb27SDimitry Andric L->contains(cast<Instruction>(BO->getOperand(1))) && 27206c3fb27SDimitry Andric Instruction::isCommutative(BO->getOpcode())) { 273349cc55cSDimitry Andric Index = cast<Instruction>(BO->getOperand(1)); 274349cc55cSDimitry Andric OtherOp = BO->getOperand(0); 275349cc55cSDimitry Andric } else { 276349cc55cSDimitry Andric return false; 277349cc55cSDimitry Andric } 278349cc55cSDimitry Andric 279349cc55cSDimitry Andric // Make sure other op is loop invariant. 280349cc55cSDimitry Andric if (!L->isLoopInvariant(OtherOp)) 281349cc55cSDimitry Andric return false; 282349cc55cSDimitry Andric 283349cc55cSDimitry Andric // Make sure we have a splat. 284349cc55cSDimitry Andric Value *SplatOp = getSplatValue(OtherOp); 285349cc55cSDimitry Andric if (!SplatOp) 286349cc55cSDimitry Andric return false; 287349cc55cSDimitry Andric 288349cc55cSDimitry Andric // Recurse up the use-def chain. 289349cc55cSDimitry Andric if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder)) 290349cc55cSDimitry Andric return false; 291349cc55cSDimitry Andric 292349cc55cSDimitry Andric // Locate the Step and Start values from the recurrence. 293349cc55cSDimitry Andric unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0; 294349cc55cSDimitry Andric unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0; 295349cc55cSDimitry Andric Value *Step = Inc->getOperand(StepIndex); 296349cc55cSDimitry Andric Value *Start = BasePtr->getOperand(StartBlock); 297349cc55cSDimitry Andric 298349cc55cSDimitry Andric // We need to adjust the start value in the preheader. 299349cc55cSDimitry Andric Builder.SetInsertPoint( 300349cc55cSDimitry Andric BasePtr->getIncomingBlock(StartBlock)->getTerminator()); 301349cc55cSDimitry Andric Builder.SetCurrentDebugLocation(DebugLoc()); 302349cc55cSDimitry Andric 303349cc55cSDimitry Andric switch (BO->getOpcode()) { 304349cc55cSDimitry Andric default: 305349cc55cSDimitry Andric llvm_unreachable("Unexpected opcode!"); 306349cc55cSDimitry Andric case Instruction::Add: 307349cc55cSDimitry Andric case Instruction::Or: { 308349cc55cSDimitry Andric // An add only affects the start value. It's ok to do this for Or because 309349cc55cSDimitry Andric // we already checked that there are no common set bits. 310349cc55cSDimitry Andric Start = Builder.CreateAdd(Start, SplatOp, "start"); 311349cc55cSDimitry Andric break; 312349cc55cSDimitry Andric } 313349cc55cSDimitry Andric case Instruction::Mul: { 314349cc55cSDimitry Andric Start = Builder.CreateMul(Start, SplatOp, "start"); 315349cc55cSDimitry Andric Step = Builder.CreateMul(Step, SplatOp, "step"); 316349cc55cSDimitry Andric Stride = Builder.CreateMul(Stride, SplatOp, "stride"); 317349cc55cSDimitry Andric break; 318349cc55cSDimitry Andric } 319349cc55cSDimitry Andric case Instruction::Shl: { 320349cc55cSDimitry Andric Start = Builder.CreateShl(Start, SplatOp, "start"); 321349cc55cSDimitry Andric Step = Builder.CreateShl(Step, SplatOp, "step"); 322349cc55cSDimitry Andric Stride = Builder.CreateShl(Stride, SplatOp, "stride"); 323349cc55cSDimitry Andric break; 324349cc55cSDimitry Andric } 325349cc55cSDimitry Andric } 326349cc55cSDimitry Andric 32706c3fb27SDimitry Andric Inc->setOperand(StepIndex, Step); 32806c3fb27SDimitry Andric BasePtr->setIncomingValue(StartBlock, Start); 329349cc55cSDimitry Andric return true; 330349cc55cSDimitry Andric } 331349cc55cSDimitry Andric 332349cc55cSDimitry Andric std::pair<Value *, Value *> 3335f757f3fSDimitry Andric RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr, 33406c3fb27SDimitry Andric IRBuilderBase &Builder) { 335349cc55cSDimitry Andric 3365f757f3fSDimitry Andric // A gather/scatter of a splat is a zero strided load/store. 3375f757f3fSDimitry Andric if (auto *BasePtr = getSplatValue(Ptr)) { 3385f757f3fSDimitry Andric Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); 3395f757f3fSDimitry Andric return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0)); 3405f757f3fSDimitry Andric } 3415f757f3fSDimitry Andric 3425f757f3fSDimitry Andric auto *GEP = dyn_cast<GetElementPtrInst>(Ptr); 3435f757f3fSDimitry Andric if (!GEP) 3445f757f3fSDimitry Andric return std::make_pair(nullptr, nullptr); 3455f757f3fSDimitry Andric 34681ad6265SDimitry Andric auto I = StridedAddrs.find(GEP); 34781ad6265SDimitry Andric if (I != StridedAddrs.end()) 34881ad6265SDimitry Andric return I->second; 34981ad6265SDimitry Andric 350349cc55cSDimitry Andric SmallVector<Value *, 2> Ops(GEP->operands()); 351349cc55cSDimitry Andric 352*0fca6ea1SDimitry Andric // If the base pointer is a vector, check if it's strided. 353*0fca6ea1SDimitry Andric Value *Base = GEP->getPointerOperand(); 354*0fca6ea1SDimitry Andric if (auto *BaseInst = dyn_cast<Instruction>(Base); 355*0fca6ea1SDimitry Andric BaseInst && BaseInst->getType()->isVectorTy()) { 356*0fca6ea1SDimitry Andric // If GEP's offset is scalar then we can add it to the base pointer's base. 357*0fca6ea1SDimitry Andric auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); }; 358*0fca6ea1SDimitry Andric if (all_of(GEP->indices(), IsScalar)) { 359*0fca6ea1SDimitry Andric auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder); 360*0fca6ea1SDimitry Andric if (BaseBase) { 361*0fca6ea1SDimitry Andric Builder.SetInsertPoint(GEP); 362*0fca6ea1SDimitry Andric SmallVector<Value *> Indices(GEP->indices()); 363*0fca6ea1SDimitry Andric Value *OffsetBase = 364*0fca6ea1SDimitry Andric Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices, 365*0fca6ea1SDimitry Andric GEP->getName() + "offset", GEP->isInBounds()); 366*0fca6ea1SDimitry Andric return {OffsetBase, Stride}; 367*0fca6ea1SDimitry Andric } 368*0fca6ea1SDimitry Andric } 369*0fca6ea1SDimitry Andric } 370*0fca6ea1SDimitry Andric 371349cc55cSDimitry Andric // Base pointer needs to be a scalar. 372*0fca6ea1SDimitry Andric Value *ScalarBase = Base; 3735f757f3fSDimitry Andric if (ScalarBase->getType()->isVectorTy()) { 3745f757f3fSDimitry Andric ScalarBase = getSplatValue(ScalarBase); 3755f757f3fSDimitry Andric if (!ScalarBase) 376349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 3775f757f3fSDimitry Andric } 378349cc55cSDimitry Andric 379bdd1243dSDimitry Andric std::optional<unsigned> VecOperand; 380349cc55cSDimitry Andric unsigned TypeScale = 0; 381349cc55cSDimitry Andric 382349cc55cSDimitry Andric // Look for a vector operand and scale. 383349cc55cSDimitry Andric gep_type_iterator GTI = gep_type_begin(GEP); 384349cc55cSDimitry Andric for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { 385349cc55cSDimitry Andric if (!Ops[i]->getType()->isVectorTy()) 386349cc55cSDimitry Andric continue; 387349cc55cSDimitry Andric 388349cc55cSDimitry Andric if (VecOperand) 389349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 390349cc55cSDimitry Andric 391349cc55cSDimitry Andric VecOperand = i; 392349cc55cSDimitry Andric 3931db9f3b2SDimitry Andric TypeSize TS = GTI.getSequentialElementStride(*DL); 394349cc55cSDimitry Andric if (TS.isScalable()) 395349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 396349cc55cSDimitry Andric 397bdd1243dSDimitry Andric TypeScale = TS.getFixedValue(); 398349cc55cSDimitry Andric } 399349cc55cSDimitry Andric 400349cc55cSDimitry Andric // We need to find a vector index to simplify. 401349cc55cSDimitry Andric if (!VecOperand) 402349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 403349cc55cSDimitry Andric 404349cc55cSDimitry Andric // We can't extract the stride if the arithmetic is done at a different size 405349cc55cSDimitry Andric // than the pointer type. Adding the stride later may not wrap correctly. 406349cc55cSDimitry Andric // Technically we could handle wider indices, but I don't expect that in 4075f757f3fSDimitry Andric // practice. Handle one special case here - constants. This simplifies 4085f757f3fSDimitry Andric // writing test cases. 409349cc55cSDimitry Andric Value *VecIndex = Ops[*VecOperand]; 410349cc55cSDimitry Andric Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType()); 4115f757f3fSDimitry Andric if (VecIndex->getType() != VecIntPtrTy) { 4125f757f3fSDimitry Andric auto *VecIndexC = dyn_cast<Constant>(VecIndex); 4135f757f3fSDimitry Andric if (!VecIndexC) 414349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 4155f757f3fSDimitry Andric if (VecIndex->getType()->getScalarSizeInBits() > VecIntPtrTy->getScalarSizeInBits()) 4165f757f3fSDimitry Andric VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC, VecIntPtrTy); 4175f757f3fSDimitry Andric else 4185f757f3fSDimitry Andric VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC, VecIntPtrTy); 4195f757f3fSDimitry Andric } 420349cc55cSDimitry Andric 421bdd1243dSDimitry Andric // Handle the non-recursive case. This is what we see if the vectorizer 422bdd1243dSDimitry Andric // decides to use a scalar IV + vid on demand instead of a vector IV. 423bdd1243dSDimitry Andric auto [Start, Stride] = matchStridedStart(VecIndex, Builder); 424bdd1243dSDimitry Andric if (Start) { 425bdd1243dSDimitry Andric assert(Stride); 426bdd1243dSDimitry Andric Builder.SetInsertPoint(GEP); 427bdd1243dSDimitry Andric 428bdd1243dSDimitry Andric // Replace the vector index with the scalar start and build a scalar GEP. 429bdd1243dSDimitry Andric Ops[*VecOperand] = Start; 430bdd1243dSDimitry Andric Type *SourceTy = GEP->getSourceElementType(); 431bdd1243dSDimitry Andric Value *BasePtr = 4325f757f3fSDimitry Andric Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front()); 433bdd1243dSDimitry Andric 434bdd1243dSDimitry Andric // Convert stride to pointer size if needed. 435bdd1243dSDimitry Andric Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); 436bdd1243dSDimitry Andric assert(Stride->getType() == IntPtrTy && "Unexpected type"); 437bdd1243dSDimitry Andric 438bdd1243dSDimitry Andric // Scale the stride by the size of the indexed type. 439bdd1243dSDimitry Andric if (TypeScale != 1) 440bdd1243dSDimitry Andric Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale)); 441bdd1243dSDimitry Andric 442bdd1243dSDimitry Andric auto P = std::make_pair(BasePtr, Stride); 443bdd1243dSDimitry Andric StridedAddrs[GEP] = P; 444bdd1243dSDimitry Andric return P; 445bdd1243dSDimitry Andric } 446bdd1243dSDimitry Andric 447bdd1243dSDimitry Andric // Make sure we're in a loop and that has a pre-header and a single latch. 448bdd1243dSDimitry Andric Loop *L = LI->getLoopFor(GEP->getParent()); 449bdd1243dSDimitry Andric if (!L || !L->getLoopPreheader() || !L->getLoopLatch()) 450bdd1243dSDimitry Andric return std::make_pair(nullptr, nullptr); 451bdd1243dSDimitry Andric 452349cc55cSDimitry Andric BinaryOperator *Inc; 453349cc55cSDimitry Andric PHINode *BasePhi; 454349cc55cSDimitry Andric if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder)) 455349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 456349cc55cSDimitry Andric 457349cc55cSDimitry Andric assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); 458349cc55cSDimitry Andric unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1; 459349cc55cSDimitry Andric assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc && 460349cc55cSDimitry Andric "Expected one operand of phi to be Inc"); 461349cc55cSDimitry Andric 462349cc55cSDimitry Andric Builder.SetInsertPoint(GEP); 463349cc55cSDimitry Andric 464349cc55cSDimitry Andric // Replace the vector index with the scalar phi and build a scalar GEP. 465349cc55cSDimitry Andric Ops[*VecOperand] = BasePhi; 466349cc55cSDimitry Andric Type *SourceTy = GEP->getSourceElementType(); 467349cc55cSDimitry Andric Value *BasePtr = 4685f757f3fSDimitry Andric Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front()); 469349cc55cSDimitry Andric 470349cc55cSDimitry Andric // Final adjustments to stride should go in the start block. 471349cc55cSDimitry Andric Builder.SetInsertPoint( 472349cc55cSDimitry Andric BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator()); 473349cc55cSDimitry Andric 474349cc55cSDimitry Andric // Convert stride to pointer size if needed. 475349cc55cSDimitry Andric Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); 476349cc55cSDimitry Andric assert(Stride->getType() == IntPtrTy && "Unexpected type"); 477349cc55cSDimitry Andric 478349cc55cSDimitry Andric // Scale the stride by the size of the indexed type. 479349cc55cSDimitry Andric if (TypeScale != 1) 480349cc55cSDimitry Andric Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale)); 481349cc55cSDimitry Andric 48281ad6265SDimitry Andric auto P = std::make_pair(BasePtr, Stride); 48381ad6265SDimitry Andric StridedAddrs[GEP] = P; 48481ad6265SDimitry Andric return P; 485349cc55cSDimitry Andric } 486349cc55cSDimitry Andric 487349cc55cSDimitry Andric bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II, 488349cc55cSDimitry Andric Type *DataType, 489349cc55cSDimitry Andric Value *Ptr, 490349cc55cSDimitry Andric Value *AlignOp) { 491349cc55cSDimitry Andric // Make sure the operation will be supported by the backend. 49206c3fb27SDimitry Andric MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue(); 49306c3fb27SDimitry Andric EVT DataTypeVT = TLI->getValueType(*DL, DataType); 49406c3fb27SDimitry Andric if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA)) 49506c3fb27SDimitry Andric return false; 49606c3fb27SDimitry Andric 49706c3fb27SDimitry Andric // FIXME: Let the backend type legalize by splitting/widening? 49806c3fb27SDimitry Andric if (!TLI->isTypeLegal(DataTypeVT)) 499349cc55cSDimitry Andric return false; 500349cc55cSDimitry Andric 5015f757f3fSDimitry Andric // Pointer should be an instruction. 5025f757f3fSDimitry Andric auto *PtrI = dyn_cast<Instruction>(Ptr); 5035f757f3fSDimitry Andric if (!PtrI) 504349cc55cSDimitry Andric return false; 505349cc55cSDimitry Andric 5065f757f3fSDimitry Andric LLVMContext &Ctx = PtrI->getContext(); 50706c3fb27SDimitry Andric IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL); 5085f757f3fSDimitry Andric Builder.SetInsertPoint(PtrI); 509349cc55cSDimitry Andric 510349cc55cSDimitry Andric Value *BasePtr, *Stride; 5115f757f3fSDimitry Andric std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder); 512349cc55cSDimitry Andric if (!BasePtr) 513349cc55cSDimitry Andric return false; 514349cc55cSDimitry Andric assert(Stride != nullptr); 515349cc55cSDimitry Andric 516349cc55cSDimitry Andric Builder.SetInsertPoint(II); 517349cc55cSDimitry Andric 518349cc55cSDimitry Andric CallInst *Call; 519349cc55cSDimitry Andric if (II->getIntrinsicID() == Intrinsic::masked_gather) 520349cc55cSDimitry Andric Call = Builder.CreateIntrinsic( 521349cc55cSDimitry Andric Intrinsic::riscv_masked_strided_load, 522349cc55cSDimitry Andric {DataType, BasePtr->getType(), Stride->getType()}, 523349cc55cSDimitry Andric {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)}); 524349cc55cSDimitry Andric else 525349cc55cSDimitry Andric Call = Builder.CreateIntrinsic( 526349cc55cSDimitry Andric Intrinsic::riscv_masked_strided_store, 527349cc55cSDimitry Andric {DataType, BasePtr->getType(), Stride->getType()}, 528349cc55cSDimitry Andric {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)}); 529349cc55cSDimitry Andric 530349cc55cSDimitry Andric Call->takeName(II); 531349cc55cSDimitry Andric II->replaceAllUsesWith(Call); 532349cc55cSDimitry Andric II->eraseFromParent(); 533349cc55cSDimitry Andric 5345f757f3fSDimitry Andric if (PtrI->use_empty()) 5355f757f3fSDimitry Andric RecursivelyDeleteTriviallyDeadInstructions(PtrI); 536349cc55cSDimitry Andric 537349cc55cSDimitry Andric return true; 538349cc55cSDimitry Andric } 539349cc55cSDimitry Andric 540349cc55cSDimitry Andric bool RISCVGatherScatterLowering::runOnFunction(Function &F) { 541349cc55cSDimitry Andric if (skipFunction(F)) 542349cc55cSDimitry Andric return false; 543349cc55cSDimitry Andric 544349cc55cSDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 545349cc55cSDimitry Andric auto &TM = TPC.getTM<RISCVTargetMachine>(); 546349cc55cSDimitry Andric ST = &TM.getSubtarget<RISCVSubtarget>(F); 547349cc55cSDimitry Andric if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors()) 548349cc55cSDimitry Andric return false; 549349cc55cSDimitry Andric 550349cc55cSDimitry Andric TLI = ST->getTargetLowering(); 551*0fca6ea1SDimitry Andric DL = &F.getDataLayout(); 552349cc55cSDimitry Andric LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 553349cc55cSDimitry Andric 55481ad6265SDimitry Andric StridedAddrs.clear(); 55581ad6265SDimitry Andric 556349cc55cSDimitry Andric SmallVector<IntrinsicInst *, 4> Gathers; 557349cc55cSDimitry Andric SmallVector<IntrinsicInst *, 4> Scatters; 558349cc55cSDimitry Andric 559349cc55cSDimitry Andric bool Changed = false; 560349cc55cSDimitry Andric 561349cc55cSDimitry Andric for (BasicBlock &BB : F) { 562349cc55cSDimitry Andric for (Instruction &I : BB) { 563349cc55cSDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 564bdd1243dSDimitry Andric if (II && II->getIntrinsicID() == Intrinsic::masked_gather) { 565349cc55cSDimitry Andric Gathers.push_back(II); 566bdd1243dSDimitry Andric } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) { 567349cc55cSDimitry Andric Scatters.push_back(II); 568349cc55cSDimitry Andric } 569349cc55cSDimitry Andric } 570349cc55cSDimitry Andric } 571349cc55cSDimitry Andric 572349cc55cSDimitry Andric // Rewrite gather/scatter to form strided load/store if possible. 573349cc55cSDimitry Andric for (auto *II : Gathers) 574349cc55cSDimitry Andric Changed |= tryCreateStridedLoadStore( 575349cc55cSDimitry Andric II, II->getType(), II->getArgOperand(0), II->getArgOperand(1)); 576349cc55cSDimitry Andric for (auto *II : Scatters) 577349cc55cSDimitry Andric Changed |= 578349cc55cSDimitry Andric tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(), 579349cc55cSDimitry Andric II->getArgOperand(1), II->getArgOperand(2)); 580349cc55cSDimitry Andric 581349cc55cSDimitry Andric // Remove any dead phis. 582349cc55cSDimitry Andric while (!MaybeDeadPHIs.empty()) { 583349cc55cSDimitry Andric if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val())) 584349cc55cSDimitry Andric RecursivelyDeleteDeadPHINode(Phi); 585349cc55cSDimitry Andric } 586349cc55cSDimitry Andric 587349cc55cSDimitry Andric return Changed; 588349cc55cSDimitry Andric } 589