1*349cc55cSDimitry Andric //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===// 2*349cc55cSDimitry Andric // 3*349cc55cSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*349cc55cSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*349cc55cSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*349cc55cSDimitry Andric // 7*349cc55cSDimitry Andric //===----------------------------------------------------------------------===// 8*349cc55cSDimitry Andric // 9*349cc55cSDimitry Andric // This pass custom lowers llvm.gather and llvm.scatter instructions to 10*349cc55cSDimitry Andric // RISCV intrinsics. 11*349cc55cSDimitry Andric // 12*349cc55cSDimitry Andric //===----------------------------------------------------------------------===// 13*349cc55cSDimitry Andric 14*349cc55cSDimitry Andric #include "RISCV.h" 15*349cc55cSDimitry Andric #include "RISCVTargetMachine.h" 16*349cc55cSDimitry Andric #include "llvm/Analysis/LoopInfo.h" 17*349cc55cSDimitry Andric #include "llvm/Analysis/ValueTracking.h" 18*349cc55cSDimitry Andric #include "llvm/Analysis/VectorUtils.h" 19*349cc55cSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 20*349cc55cSDimitry Andric #include "llvm/IR/GetElementPtrTypeIterator.h" 21*349cc55cSDimitry Andric #include "llvm/IR/IRBuilder.h" 22*349cc55cSDimitry Andric #include "llvm/IR/IntrinsicInst.h" 23*349cc55cSDimitry Andric #include "llvm/IR/IntrinsicsRISCV.h" 24*349cc55cSDimitry Andric #include "llvm/Transforms/Utils/Local.h" 25*349cc55cSDimitry Andric 26*349cc55cSDimitry Andric using namespace llvm; 27*349cc55cSDimitry Andric 28*349cc55cSDimitry Andric #define DEBUG_TYPE "riscv-gather-scatter-lowering" 29*349cc55cSDimitry Andric 30*349cc55cSDimitry Andric namespace { 31*349cc55cSDimitry Andric 32*349cc55cSDimitry Andric class RISCVGatherScatterLowering : public FunctionPass { 33*349cc55cSDimitry Andric const RISCVSubtarget *ST = nullptr; 34*349cc55cSDimitry Andric const RISCVTargetLowering *TLI = nullptr; 35*349cc55cSDimitry Andric LoopInfo *LI = nullptr; 36*349cc55cSDimitry Andric const DataLayout *DL = nullptr; 37*349cc55cSDimitry Andric 38*349cc55cSDimitry Andric SmallVector<WeakTrackingVH> MaybeDeadPHIs; 39*349cc55cSDimitry Andric 40*349cc55cSDimitry Andric public: 41*349cc55cSDimitry Andric static char ID; // Pass identification, replacement for typeid 42*349cc55cSDimitry Andric 43*349cc55cSDimitry Andric RISCVGatherScatterLowering() : FunctionPass(ID) {} 44*349cc55cSDimitry Andric 45*349cc55cSDimitry Andric bool runOnFunction(Function &F) override; 46*349cc55cSDimitry Andric 47*349cc55cSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 48*349cc55cSDimitry Andric AU.setPreservesCFG(); 49*349cc55cSDimitry Andric AU.addRequired<TargetPassConfig>(); 50*349cc55cSDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 51*349cc55cSDimitry Andric } 52*349cc55cSDimitry Andric 53*349cc55cSDimitry Andric StringRef getPassName() const override { 54*349cc55cSDimitry Andric return "RISCV gather/scatter lowering"; 55*349cc55cSDimitry Andric } 56*349cc55cSDimitry Andric 57*349cc55cSDimitry Andric private: 58*349cc55cSDimitry Andric bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp); 59*349cc55cSDimitry Andric 60*349cc55cSDimitry Andric bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr, 61*349cc55cSDimitry Andric Value *AlignOp); 62*349cc55cSDimitry Andric 63*349cc55cSDimitry Andric std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP, 64*349cc55cSDimitry Andric IRBuilder<> &Builder); 65*349cc55cSDimitry Andric 66*349cc55cSDimitry Andric bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride, 67*349cc55cSDimitry Andric PHINode *&BasePtr, BinaryOperator *&Inc, 68*349cc55cSDimitry Andric IRBuilder<> &Builder); 69*349cc55cSDimitry Andric }; 70*349cc55cSDimitry Andric 71*349cc55cSDimitry Andric } // end anonymous namespace 72*349cc55cSDimitry Andric 73*349cc55cSDimitry Andric char RISCVGatherScatterLowering::ID = 0; 74*349cc55cSDimitry Andric 75*349cc55cSDimitry Andric INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE, 76*349cc55cSDimitry Andric "RISCV gather/scatter lowering pass", false, false) 77*349cc55cSDimitry Andric 78*349cc55cSDimitry Andric FunctionPass *llvm::createRISCVGatherScatterLoweringPass() { 79*349cc55cSDimitry Andric return new RISCVGatherScatterLowering(); 80*349cc55cSDimitry Andric } 81*349cc55cSDimitry Andric 82*349cc55cSDimitry Andric bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType, 83*349cc55cSDimitry Andric Value *AlignOp) { 84*349cc55cSDimitry Andric Type *ScalarType = DataType->getScalarType(); 85*349cc55cSDimitry Andric if (!TLI->isLegalElementTypeForRVV(ScalarType)) 86*349cc55cSDimitry Andric return false; 87*349cc55cSDimitry Andric 88*349cc55cSDimitry Andric MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue(); 89*349cc55cSDimitry Andric if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize()) 90*349cc55cSDimitry Andric return false; 91*349cc55cSDimitry Andric 92*349cc55cSDimitry Andric // FIXME: Let the backend type legalize by splitting/widening? 93*349cc55cSDimitry Andric EVT DataVT = TLI->getValueType(*DL, DataType); 94*349cc55cSDimitry Andric if (!TLI->isTypeLegal(DataVT)) 95*349cc55cSDimitry Andric return false; 96*349cc55cSDimitry Andric 97*349cc55cSDimitry Andric return true; 98*349cc55cSDimitry Andric } 99*349cc55cSDimitry Andric 100*349cc55cSDimitry Andric // TODO: Should we consider the mask when looking for a stride? 101*349cc55cSDimitry Andric static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) { 102*349cc55cSDimitry Andric unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements(); 103*349cc55cSDimitry Andric 104*349cc55cSDimitry Andric // Check that the start value is a strided constant. 105*349cc55cSDimitry Andric auto *StartVal = 106*349cc55cSDimitry Andric dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0)); 107*349cc55cSDimitry Andric if (!StartVal) 108*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 109*349cc55cSDimitry Andric APInt StrideVal(StartVal->getValue().getBitWidth(), 0); 110*349cc55cSDimitry Andric ConstantInt *Prev = StartVal; 111*349cc55cSDimitry Andric for (unsigned i = 1; i != NumElts; ++i) { 112*349cc55cSDimitry Andric auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i)); 113*349cc55cSDimitry Andric if (!C) 114*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 115*349cc55cSDimitry Andric 116*349cc55cSDimitry Andric APInt LocalStride = C->getValue() - Prev->getValue(); 117*349cc55cSDimitry Andric if (i == 1) 118*349cc55cSDimitry Andric StrideVal = LocalStride; 119*349cc55cSDimitry Andric else if (StrideVal != LocalStride) 120*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 121*349cc55cSDimitry Andric 122*349cc55cSDimitry Andric Prev = C; 123*349cc55cSDimitry Andric } 124*349cc55cSDimitry Andric 125*349cc55cSDimitry Andric Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal); 126*349cc55cSDimitry Andric 127*349cc55cSDimitry Andric return std::make_pair(StartVal, Stride); 128*349cc55cSDimitry Andric } 129*349cc55cSDimitry Andric 130*349cc55cSDimitry Andric // Recursively, walk about the use-def chain until we find a Phi with a strided 131*349cc55cSDimitry Andric // start value. Build and update a scalar recurrence as we unwind the recursion. 132*349cc55cSDimitry Andric // We also update the Stride as we unwind. Our goal is to move all of the 133*349cc55cSDimitry Andric // arithmetic out of the loop. 134*349cc55cSDimitry Andric bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L, 135*349cc55cSDimitry Andric Value *&Stride, 136*349cc55cSDimitry Andric PHINode *&BasePtr, 137*349cc55cSDimitry Andric BinaryOperator *&Inc, 138*349cc55cSDimitry Andric IRBuilder<> &Builder) { 139*349cc55cSDimitry Andric // Our base case is a Phi. 140*349cc55cSDimitry Andric if (auto *Phi = dyn_cast<PHINode>(Index)) { 141*349cc55cSDimitry Andric // A phi node we want to perform this function on should be from the 142*349cc55cSDimitry Andric // loop header. 143*349cc55cSDimitry Andric if (Phi->getParent() != L->getHeader()) 144*349cc55cSDimitry Andric return false; 145*349cc55cSDimitry Andric 146*349cc55cSDimitry Andric Value *Step, *Start; 147*349cc55cSDimitry Andric if (!matchSimpleRecurrence(Phi, Inc, Start, Step) || 148*349cc55cSDimitry Andric Inc->getOpcode() != Instruction::Add) 149*349cc55cSDimitry Andric return false; 150*349cc55cSDimitry Andric assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); 151*349cc55cSDimitry Andric unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1; 152*349cc55cSDimitry Andric assert(Phi->getIncomingValue(IncrementingBlock) == Inc && 153*349cc55cSDimitry Andric "Expected one operand of phi to be Inc"); 154*349cc55cSDimitry Andric 155*349cc55cSDimitry Andric // Only proceed if the step is loop invariant. 156*349cc55cSDimitry Andric if (!L->isLoopInvariant(Step)) 157*349cc55cSDimitry Andric return false; 158*349cc55cSDimitry Andric 159*349cc55cSDimitry Andric // Step should be a splat. 160*349cc55cSDimitry Andric Step = getSplatValue(Step); 161*349cc55cSDimitry Andric if (!Step) 162*349cc55cSDimitry Andric return false; 163*349cc55cSDimitry Andric 164*349cc55cSDimitry Andric // Start should be a strided constant. 165*349cc55cSDimitry Andric auto *StartC = dyn_cast<Constant>(Start); 166*349cc55cSDimitry Andric if (!StartC) 167*349cc55cSDimitry Andric return false; 168*349cc55cSDimitry Andric 169*349cc55cSDimitry Andric std::tie(Start, Stride) = matchStridedConstant(StartC); 170*349cc55cSDimitry Andric if (!Start) 171*349cc55cSDimitry Andric return false; 172*349cc55cSDimitry Andric assert(Stride != nullptr); 173*349cc55cSDimitry Andric 174*349cc55cSDimitry Andric // Build scalar phi and increment. 175*349cc55cSDimitry Andric BasePtr = 176*349cc55cSDimitry Andric PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi); 177*349cc55cSDimitry Andric Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar", 178*349cc55cSDimitry Andric Inc); 179*349cc55cSDimitry Andric BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock)); 180*349cc55cSDimitry Andric BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock)); 181*349cc55cSDimitry Andric 182*349cc55cSDimitry Andric // Note that this Phi might be eligible for removal. 183*349cc55cSDimitry Andric MaybeDeadPHIs.push_back(Phi); 184*349cc55cSDimitry Andric return true; 185*349cc55cSDimitry Andric } 186*349cc55cSDimitry Andric 187*349cc55cSDimitry Andric // Otherwise look for binary operator. 188*349cc55cSDimitry Andric auto *BO = dyn_cast<BinaryOperator>(Index); 189*349cc55cSDimitry Andric if (!BO) 190*349cc55cSDimitry Andric return false; 191*349cc55cSDimitry Andric 192*349cc55cSDimitry Andric if (BO->getOpcode() != Instruction::Add && 193*349cc55cSDimitry Andric BO->getOpcode() != Instruction::Or && 194*349cc55cSDimitry Andric BO->getOpcode() != Instruction::Mul && 195*349cc55cSDimitry Andric BO->getOpcode() != Instruction::Shl) 196*349cc55cSDimitry Andric return false; 197*349cc55cSDimitry Andric 198*349cc55cSDimitry Andric // Only support shift by constant. 199*349cc55cSDimitry Andric if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1))) 200*349cc55cSDimitry Andric return false; 201*349cc55cSDimitry Andric 202*349cc55cSDimitry Andric // We need to be able to treat Or as Add. 203*349cc55cSDimitry Andric if (BO->getOpcode() == Instruction::Or && 204*349cc55cSDimitry Andric !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL)) 205*349cc55cSDimitry Andric return false; 206*349cc55cSDimitry Andric 207*349cc55cSDimitry Andric // We should have one operand in the loop and one splat. 208*349cc55cSDimitry Andric Value *OtherOp; 209*349cc55cSDimitry Andric if (isa<Instruction>(BO->getOperand(0)) && 210*349cc55cSDimitry Andric L->contains(cast<Instruction>(BO->getOperand(0)))) { 211*349cc55cSDimitry Andric Index = cast<Instruction>(BO->getOperand(0)); 212*349cc55cSDimitry Andric OtherOp = BO->getOperand(1); 213*349cc55cSDimitry Andric } else if (isa<Instruction>(BO->getOperand(1)) && 214*349cc55cSDimitry Andric L->contains(cast<Instruction>(BO->getOperand(1)))) { 215*349cc55cSDimitry Andric Index = cast<Instruction>(BO->getOperand(1)); 216*349cc55cSDimitry Andric OtherOp = BO->getOperand(0); 217*349cc55cSDimitry Andric } else { 218*349cc55cSDimitry Andric return false; 219*349cc55cSDimitry Andric } 220*349cc55cSDimitry Andric 221*349cc55cSDimitry Andric // Make sure other op is loop invariant. 222*349cc55cSDimitry Andric if (!L->isLoopInvariant(OtherOp)) 223*349cc55cSDimitry Andric return false; 224*349cc55cSDimitry Andric 225*349cc55cSDimitry Andric // Make sure we have a splat. 226*349cc55cSDimitry Andric Value *SplatOp = getSplatValue(OtherOp); 227*349cc55cSDimitry Andric if (!SplatOp) 228*349cc55cSDimitry Andric return false; 229*349cc55cSDimitry Andric 230*349cc55cSDimitry Andric // Recurse up the use-def chain. 231*349cc55cSDimitry Andric if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder)) 232*349cc55cSDimitry Andric return false; 233*349cc55cSDimitry Andric 234*349cc55cSDimitry Andric // Locate the Step and Start values from the recurrence. 235*349cc55cSDimitry Andric unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0; 236*349cc55cSDimitry Andric unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0; 237*349cc55cSDimitry Andric Value *Step = Inc->getOperand(StepIndex); 238*349cc55cSDimitry Andric Value *Start = BasePtr->getOperand(StartBlock); 239*349cc55cSDimitry Andric 240*349cc55cSDimitry Andric // We need to adjust the start value in the preheader. 241*349cc55cSDimitry Andric Builder.SetInsertPoint( 242*349cc55cSDimitry Andric BasePtr->getIncomingBlock(StartBlock)->getTerminator()); 243*349cc55cSDimitry Andric Builder.SetCurrentDebugLocation(DebugLoc()); 244*349cc55cSDimitry Andric 245*349cc55cSDimitry Andric switch (BO->getOpcode()) { 246*349cc55cSDimitry Andric default: 247*349cc55cSDimitry Andric llvm_unreachable("Unexpected opcode!"); 248*349cc55cSDimitry Andric case Instruction::Add: 249*349cc55cSDimitry Andric case Instruction::Or: { 250*349cc55cSDimitry Andric // An add only affects the start value. It's ok to do this for Or because 251*349cc55cSDimitry Andric // we already checked that there are no common set bits. 252*349cc55cSDimitry Andric 253*349cc55cSDimitry Andric // If the start value is Zero, just take the SplatOp. 254*349cc55cSDimitry Andric if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero()) 255*349cc55cSDimitry Andric Start = SplatOp; 256*349cc55cSDimitry Andric else 257*349cc55cSDimitry Andric Start = Builder.CreateAdd(Start, SplatOp, "start"); 258*349cc55cSDimitry Andric BasePtr->setIncomingValue(StartBlock, Start); 259*349cc55cSDimitry Andric break; 260*349cc55cSDimitry Andric } 261*349cc55cSDimitry Andric case Instruction::Mul: { 262*349cc55cSDimitry Andric // If the start is zero we don't need to multiply. 263*349cc55cSDimitry Andric if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero()) 264*349cc55cSDimitry Andric Start = Builder.CreateMul(Start, SplatOp, "start"); 265*349cc55cSDimitry Andric 266*349cc55cSDimitry Andric Step = Builder.CreateMul(Step, SplatOp, "step"); 267*349cc55cSDimitry Andric 268*349cc55cSDimitry Andric // If the Stride is 1 just take the SplatOpt. 269*349cc55cSDimitry Andric if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne()) 270*349cc55cSDimitry Andric Stride = SplatOp; 271*349cc55cSDimitry Andric else 272*349cc55cSDimitry Andric Stride = Builder.CreateMul(Stride, SplatOp, "stride"); 273*349cc55cSDimitry Andric Inc->setOperand(StepIndex, Step); 274*349cc55cSDimitry Andric BasePtr->setIncomingValue(StartBlock, Start); 275*349cc55cSDimitry Andric break; 276*349cc55cSDimitry Andric } 277*349cc55cSDimitry Andric case Instruction::Shl: { 278*349cc55cSDimitry Andric // If the start is zero we don't need to shift. 279*349cc55cSDimitry Andric if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero()) 280*349cc55cSDimitry Andric Start = Builder.CreateShl(Start, SplatOp, "start"); 281*349cc55cSDimitry Andric Step = Builder.CreateShl(Step, SplatOp, "step"); 282*349cc55cSDimitry Andric Stride = Builder.CreateShl(Stride, SplatOp, "stride"); 283*349cc55cSDimitry Andric Inc->setOperand(StepIndex, Step); 284*349cc55cSDimitry Andric BasePtr->setIncomingValue(StartBlock, Start); 285*349cc55cSDimitry Andric break; 286*349cc55cSDimitry Andric } 287*349cc55cSDimitry Andric } 288*349cc55cSDimitry Andric 289*349cc55cSDimitry Andric return true; 290*349cc55cSDimitry Andric } 291*349cc55cSDimitry Andric 292*349cc55cSDimitry Andric std::pair<Value *, Value *> 293*349cc55cSDimitry Andric RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP, 294*349cc55cSDimitry Andric IRBuilder<> &Builder) { 295*349cc55cSDimitry Andric 296*349cc55cSDimitry Andric SmallVector<Value *, 2> Ops(GEP->operands()); 297*349cc55cSDimitry Andric 298*349cc55cSDimitry Andric // Base pointer needs to be a scalar. 299*349cc55cSDimitry Andric if (Ops[0]->getType()->isVectorTy()) 300*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 301*349cc55cSDimitry Andric 302*349cc55cSDimitry Andric // Make sure we're in a loop and it is in loop simplify form. 303*349cc55cSDimitry Andric Loop *L = LI->getLoopFor(GEP->getParent()); 304*349cc55cSDimitry Andric if (!L || !L->isLoopSimplifyForm()) 305*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 306*349cc55cSDimitry Andric 307*349cc55cSDimitry Andric Optional<unsigned> VecOperand; 308*349cc55cSDimitry Andric unsigned TypeScale = 0; 309*349cc55cSDimitry Andric 310*349cc55cSDimitry Andric // Look for a vector operand and scale. 311*349cc55cSDimitry Andric gep_type_iterator GTI = gep_type_begin(GEP); 312*349cc55cSDimitry Andric for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { 313*349cc55cSDimitry Andric if (!Ops[i]->getType()->isVectorTy()) 314*349cc55cSDimitry Andric continue; 315*349cc55cSDimitry Andric 316*349cc55cSDimitry Andric if (VecOperand) 317*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 318*349cc55cSDimitry Andric 319*349cc55cSDimitry Andric VecOperand = i; 320*349cc55cSDimitry Andric 321*349cc55cSDimitry Andric TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType()); 322*349cc55cSDimitry Andric if (TS.isScalable()) 323*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 324*349cc55cSDimitry Andric 325*349cc55cSDimitry Andric TypeScale = TS.getFixedSize(); 326*349cc55cSDimitry Andric } 327*349cc55cSDimitry Andric 328*349cc55cSDimitry Andric // We need to find a vector index to simplify. 329*349cc55cSDimitry Andric if (!VecOperand) 330*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 331*349cc55cSDimitry Andric 332*349cc55cSDimitry Andric // We can't extract the stride if the arithmetic is done at a different size 333*349cc55cSDimitry Andric // than the pointer type. Adding the stride later may not wrap correctly. 334*349cc55cSDimitry Andric // Technically we could handle wider indices, but I don't expect that in 335*349cc55cSDimitry Andric // practice. 336*349cc55cSDimitry Andric Value *VecIndex = Ops[*VecOperand]; 337*349cc55cSDimitry Andric Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType()); 338*349cc55cSDimitry Andric if (VecIndex->getType() != VecIntPtrTy) 339*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 340*349cc55cSDimitry Andric 341*349cc55cSDimitry Andric Value *Stride; 342*349cc55cSDimitry Andric BinaryOperator *Inc; 343*349cc55cSDimitry Andric PHINode *BasePhi; 344*349cc55cSDimitry Andric if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder)) 345*349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 346*349cc55cSDimitry Andric 347*349cc55cSDimitry Andric assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); 348*349cc55cSDimitry Andric unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1; 349*349cc55cSDimitry Andric assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc && 350*349cc55cSDimitry Andric "Expected one operand of phi to be Inc"); 351*349cc55cSDimitry Andric 352*349cc55cSDimitry Andric Builder.SetInsertPoint(GEP); 353*349cc55cSDimitry Andric 354*349cc55cSDimitry Andric // Replace the vector index with the scalar phi and build a scalar GEP. 355*349cc55cSDimitry Andric Ops[*VecOperand] = BasePhi; 356*349cc55cSDimitry Andric Type *SourceTy = GEP->getSourceElementType(); 357*349cc55cSDimitry Andric Value *BasePtr = 358*349cc55cSDimitry Andric Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front()); 359*349cc55cSDimitry Andric 360*349cc55cSDimitry Andric // Cast the GEP to an i8*. 361*349cc55cSDimitry Andric LLVMContext &Ctx = GEP->getContext(); 362*349cc55cSDimitry Andric Type *I8PtrTy = 363*349cc55cSDimitry Andric Type::getInt8PtrTy(Ctx, GEP->getType()->getPointerAddressSpace()); 364*349cc55cSDimitry Andric if (BasePtr->getType() != I8PtrTy) 365*349cc55cSDimitry Andric BasePtr = Builder.CreatePointerCast(BasePtr, I8PtrTy); 366*349cc55cSDimitry Andric 367*349cc55cSDimitry Andric // Final adjustments to stride should go in the start block. 368*349cc55cSDimitry Andric Builder.SetInsertPoint( 369*349cc55cSDimitry Andric BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator()); 370*349cc55cSDimitry Andric 371*349cc55cSDimitry Andric // Convert stride to pointer size if needed. 372*349cc55cSDimitry Andric Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); 373*349cc55cSDimitry Andric assert(Stride->getType() == IntPtrTy && "Unexpected type"); 374*349cc55cSDimitry Andric 375*349cc55cSDimitry Andric // Scale the stride by the size of the indexed type. 376*349cc55cSDimitry Andric if (TypeScale != 1) 377*349cc55cSDimitry Andric Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale)); 378*349cc55cSDimitry Andric 379*349cc55cSDimitry Andric return std::make_pair(BasePtr, Stride); 380*349cc55cSDimitry Andric } 381*349cc55cSDimitry Andric 382*349cc55cSDimitry Andric bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II, 383*349cc55cSDimitry Andric Type *DataType, 384*349cc55cSDimitry Andric Value *Ptr, 385*349cc55cSDimitry Andric Value *AlignOp) { 386*349cc55cSDimitry Andric // Make sure the operation will be supported by the backend. 387*349cc55cSDimitry Andric if (!isLegalTypeAndAlignment(DataType, AlignOp)) 388*349cc55cSDimitry Andric return false; 389*349cc55cSDimitry Andric 390*349cc55cSDimitry Andric // Pointer should be a GEP. 391*349cc55cSDimitry Andric auto *GEP = dyn_cast<GetElementPtrInst>(Ptr); 392*349cc55cSDimitry Andric if (!GEP) 393*349cc55cSDimitry Andric return false; 394*349cc55cSDimitry Andric 395*349cc55cSDimitry Andric IRBuilder<> Builder(GEP); 396*349cc55cSDimitry Andric 397*349cc55cSDimitry Andric Value *BasePtr, *Stride; 398*349cc55cSDimitry Andric std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder); 399*349cc55cSDimitry Andric if (!BasePtr) 400*349cc55cSDimitry Andric return false; 401*349cc55cSDimitry Andric assert(Stride != nullptr); 402*349cc55cSDimitry Andric 403*349cc55cSDimitry Andric Builder.SetInsertPoint(II); 404*349cc55cSDimitry Andric 405*349cc55cSDimitry Andric CallInst *Call; 406*349cc55cSDimitry Andric if (II->getIntrinsicID() == Intrinsic::masked_gather) 407*349cc55cSDimitry Andric Call = Builder.CreateIntrinsic( 408*349cc55cSDimitry Andric Intrinsic::riscv_masked_strided_load, 409*349cc55cSDimitry Andric {DataType, BasePtr->getType(), Stride->getType()}, 410*349cc55cSDimitry Andric {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)}); 411*349cc55cSDimitry Andric else 412*349cc55cSDimitry Andric Call = Builder.CreateIntrinsic( 413*349cc55cSDimitry Andric Intrinsic::riscv_masked_strided_store, 414*349cc55cSDimitry Andric {DataType, BasePtr->getType(), Stride->getType()}, 415*349cc55cSDimitry Andric {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)}); 416*349cc55cSDimitry Andric 417*349cc55cSDimitry Andric Call->takeName(II); 418*349cc55cSDimitry Andric II->replaceAllUsesWith(Call); 419*349cc55cSDimitry Andric II->eraseFromParent(); 420*349cc55cSDimitry Andric 421*349cc55cSDimitry Andric if (GEP->use_empty()) 422*349cc55cSDimitry Andric RecursivelyDeleteTriviallyDeadInstructions(GEP); 423*349cc55cSDimitry Andric 424*349cc55cSDimitry Andric return true; 425*349cc55cSDimitry Andric } 426*349cc55cSDimitry Andric 427*349cc55cSDimitry Andric bool RISCVGatherScatterLowering::runOnFunction(Function &F) { 428*349cc55cSDimitry Andric if (skipFunction(F)) 429*349cc55cSDimitry Andric return false; 430*349cc55cSDimitry Andric 431*349cc55cSDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 432*349cc55cSDimitry Andric auto &TM = TPC.getTM<RISCVTargetMachine>(); 433*349cc55cSDimitry Andric ST = &TM.getSubtarget<RISCVSubtarget>(F); 434*349cc55cSDimitry Andric if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors()) 435*349cc55cSDimitry Andric return false; 436*349cc55cSDimitry Andric 437*349cc55cSDimitry Andric TLI = ST->getTargetLowering(); 438*349cc55cSDimitry Andric DL = &F.getParent()->getDataLayout(); 439*349cc55cSDimitry Andric LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 440*349cc55cSDimitry Andric 441*349cc55cSDimitry Andric SmallVector<IntrinsicInst *, 4> Gathers; 442*349cc55cSDimitry Andric SmallVector<IntrinsicInst *, 4> Scatters; 443*349cc55cSDimitry Andric 444*349cc55cSDimitry Andric bool Changed = false; 445*349cc55cSDimitry Andric 446*349cc55cSDimitry Andric for (BasicBlock &BB : F) { 447*349cc55cSDimitry Andric for (Instruction &I : BB) { 448*349cc55cSDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 449*349cc55cSDimitry Andric if (II && II->getIntrinsicID() == Intrinsic::masked_gather && 450*349cc55cSDimitry Andric isa<FixedVectorType>(II->getType())) { 451*349cc55cSDimitry Andric Gathers.push_back(II); 452*349cc55cSDimitry Andric } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter && 453*349cc55cSDimitry Andric isa<FixedVectorType>(II->getArgOperand(0)->getType())) { 454*349cc55cSDimitry Andric Scatters.push_back(II); 455*349cc55cSDimitry Andric } 456*349cc55cSDimitry Andric } 457*349cc55cSDimitry Andric } 458*349cc55cSDimitry Andric 459*349cc55cSDimitry Andric // Rewrite gather/scatter to form strided load/store if possible. 460*349cc55cSDimitry Andric for (auto *II : Gathers) 461*349cc55cSDimitry Andric Changed |= tryCreateStridedLoadStore( 462*349cc55cSDimitry Andric II, II->getType(), II->getArgOperand(0), II->getArgOperand(1)); 463*349cc55cSDimitry Andric for (auto *II : Scatters) 464*349cc55cSDimitry Andric Changed |= 465*349cc55cSDimitry Andric tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(), 466*349cc55cSDimitry Andric II->getArgOperand(1), II->getArgOperand(2)); 467*349cc55cSDimitry Andric 468*349cc55cSDimitry Andric // Remove any dead phis. 469*349cc55cSDimitry Andric while (!MaybeDeadPHIs.empty()) { 470*349cc55cSDimitry Andric if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val())) 471*349cc55cSDimitry Andric RecursivelyDeleteDeadPHINode(Phi); 472*349cc55cSDimitry Andric } 473*349cc55cSDimitry Andric 474*349cc55cSDimitry Andric return Changed; 475*349cc55cSDimitry Andric } 476