1349cc55cSDimitry Andric //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===// 2349cc55cSDimitry Andric // 3349cc55cSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4349cc55cSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5349cc55cSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6349cc55cSDimitry Andric // 7349cc55cSDimitry Andric //===----------------------------------------------------------------------===// 8349cc55cSDimitry Andric // 9349cc55cSDimitry Andric // This pass custom lowers llvm.gather and llvm.scatter instructions to 10349cc55cSDimitry Andric // RISCV intrinsics. 11349cc55cSDimitry Andric // 12349cc55cSDimitry Andric //===----------------------------------------------------------------------===// 13349cc55cSDimitry Andric 14349cc55cSDimitry Andric #include "RISCV.h" 15349cc55cSDimitry Andric #include "RISCVTargetMachine.h" 16349cc55cSDimitry Andric #include "llvm/Analysis/LoopInfo.h" 17349cc55cSDimitry Andric #include "llvm/Analysis/ValueTracking.h" 18349cc55cSDimitry Andric #include "llvm/Analysis/VectorUtils.h" 19349cc55cSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 20349cc55cSDimitry Andric #include "llvm/IR/GetElementPtrTypeIterator.h" 21349cc55cSDimitry Andric #include "llvm/IR/IRBuilder.h" 22349cc55cSDimitry Andric #include "llvm/IR/IntrinsicInst.h" 23349cc55cSDimitry Andric #include "llvm/IR/IntrinsicsRISCV.h" 24349cc55cSDimitry Andric #include "llvm/Transforms/Utils/Local.h" 25349cc55cSDimitry Andric 26349cc55cSDimitry Andric using namespace llvm; 27349cc55cSDimitry Andric 28349cc55cSDimitry Andric #define DEBUG_TYPE "riscv-gather-scatter-lowering" 29349cc55cSDimitry Andric 30349cc55cSDimitry Andric namespace { 31349cc55cSDimitry Andric 32349cc55cSDimitry Andric class RISCVGatherScatterLowering : public FunctionPass { 33349cc55cSDimitry Andric const RISCVSubtarget *ST = nullptr; 34349cc55cSDimitry Andric const RISCVTargetLowering *TLI = nullptr; 35349cc55cSDimitry Andric LoopInfo *LI = nullptr; 36349cc55cSDimitry Andric const DataLayout *DL = nullptr; 37349cc55cSDimitry Andric 38349cc55cSDimitry Andric SmallVector<WeakTrackingVH> MaybeDeadPHIs; 39349cc55cSDimitry Andric 40349cc55cSDimitry Andric public: 41349cc55cSDimitry Andric static char ID; // Pass identification, replacement for typeid 42349cc55cSDimitry Andric 43349cc55cSDimitry Andric RISCVGatherScatterLowering() : FunctionPass(ID) {} 44349cc55cSDimitry Andric 45349cc55cSDimitry Andric bool runOnFunction(Function &F) override; 46349cc55cSDimitry Andric 47349cc55cSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 48349cc55cSDimitry Andric AU.setPreservesCFG(); 49349cc55cSDimitry Andric AU.addRequired<TargetPassConfig>(); 50349cc55cSDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 51349cc55cSDimitry Andric } 52349cc55cSDimitry Andric 53349cc55cSDimitry Andric StringRef getPassName() const override { 54349cc55cSDimitry Andric return "RISCV gather/scatter lowering"; 55349cc55cSDimitry Andric } 56349cc55cSDimitry Andric 57349cc55cSDimitry Andric private: 58349cc55cSDimitry Andric bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp); 59349cc55cSDimitry Andric 60349cc55cSDimitry Andric bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr, 61349cc55cSDimitry Andric Value *AlignOp); 62349cc55cSDimitry Andric 63349cc55cSDimitry Andric std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP, 64349cc55cSDimitry Andric IRBuilder<> &Builder); 65349cc55cSDimitry Andric 66349cc55cSDimitry Andric bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride, 67349cc55cSDimitry Andric PHINode *&BasePtr, BinaryOperator *&Inc, 68349cc55cSDimitry Andric IRBuilder<> &Builder); 69349cc55cSDimitry Andric }; 70349cc55cSDimitry Andric 71349cc55cSDimitry Andric } // end anonymous namespace 72349cc55cSDimitry Andric 73349cc55cSDimitry Andric char RISCVGatherScatterLowering::ID = 0; 74349cc55cSDimitry Andric 75349cc55cSDimitry Andric INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE, 76349cc55cSDimitry Andric "RISCV gather/scatter lowering pass", false, false) 77349cc55cSDimitry Andric 78349cc55cSDimitry Andric FunctionPass *llvm::createRISCVGatherScatterLoweringPass() { 79349cc55cSDimitry Andric return new RISCVGatherScatterLowering(); 80349cc55cSDimitry Andric } 81349cc55cSDimitry Andric 82349cc55cSDimitry Andric bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType, 83349cc55cSDimitry Andric Value *AlignOp) { 84349cc55cSDimitry Andric Type *ScalarType = DataType->getScalarType(); 85349cc55cSDimitry Andric if (!TLI->isLegalElementTypeForRVV(ScalarType)) 86349cc55cSDimitry Andric return false; 87349cc55cSDimitry Andric 88349cc55cSDimitry Andric MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue(); 89349cc55cSDimitry Andric if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize()) 90349cc55cSDimitry Andric return false; 91349cc55cSDimitry Andric 92349cc55cSDimitry Andric // FIXME: Let the backend type legalize by splitting/widening? 93349cc55cSDimitry Andric EVT DataVT = TLI->getValueType(*DL, DataType); 94349cc55cSDimitry Andric if (!TLI->isTypeLegal(DataVT)) 95349cc55cSDimitry Andric return false; 96349cc55cSDimitry Andric 97349cc55cSDimitry Andric return true; 98349cc55cSDimitry Andric } 99349cc55cSDimitry Andric 100349cc55cSDimitry Andric // TODO: Should we consider the mask when looking for a stride? 101349cc55cSDimitry Andric static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) { 102349cc55cSDimitry Andric unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements(); 103349cc55cSDimitry Andric 104349cc55cSDimitry Andric // Check that the start value is a strided constant. 105349cc55cSDimitry Andric auto *StartVal = 106349cc55cSDimitry Andric dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0)); 107349cc55cSDimitry Andric if (!StartVal) 108349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 109349cc55cSDimitry Andric APInt StrideVal(StartVal->getValue().getBitWidth(), 0); 110349cc55cSDimitry Andric ConstantInt *Prev = StartVal; 111349cc55cSDimitry Andric for (unsigned i = 1; i != NumElts; ++i) { 112349cc55cSDimitry Andric auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i)); 113349cc55cSDimitry Andric if (!C) 114349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 115349cc55cSDimitry Andric 116349cc55cSDimitry Andric APInt LocalStride = C->getValue() - Prev->getValue(); 117349cc55cSDimitry Andric if (i == 1) 118349cc55cSDimitry Andric StrideVal = LocalStride; 119349cc55cSDimitry Andric else if (StrideVal != LocalStride) 120349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 121349cc55cSDimitry Andric 122349cc55cSDimitry Andric Prev = C; 123349cc55cSDimitry Andric } 124349cc55cSDimitry Andric 125349cc55cSDimitry Andric Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal); 126349cc55cSDimitry Andric 127349cc55cSDimitry Andric return std::make_pair(StartVal, Stride); 128349cc55cSDimitry Andric } 129349cc55cSDimitry Andric 130*04eeddc0SDimitry Andric static std::pair<Value *, Value *> matchStridedStart(Value *Start, 131*04eeddc0SDimitry Andric IRBuilder<> &Builder) { 132*04eeddc0SDimitry Andric // Base case, start is a strided constant. 133*04eeddc0SDimitry Andric auto *StartC = dyn_cast<Constant>(Start); 134*04eeddc0SDimitry Andric if (StartC) 135*04eeddc0SDimitry Andric return matchStridedConstant(StartC); 136*04eeddc0SDimitry Andric 137*04eeddc0SDimitry Andric // Not a constant, maybe it's a strided constant with a splat added to it. 138*04eeddc0SDimitry Andric auto *BO = dyn_cast<BinaryOperator>(Start); 139*04eeddc0SDimitry Andric if (!BO || BO->getOpcode() != Instruction::Add) 140*04eeddc0SDimitry Andric return std::make_pair(nullptr, nullptr); 141*04eeddc0SDimitry Andric 142*04eeddc0SDimitry Andric // Look for an operand that is splatted. 143*04eeddc0SDimitry Andric unsigned OtherIndex = 1; 144*04eeddc0SDimitry Andric Value *Splat = getSplatValue(BO->getOperand(0)); 145*04eeddc0SDimitry Andric if (!Splat) { 146*04eeddc0SDimitry Andric Splat = getSplatValue(BO->getOperand(1)); 147*04eeddc0SDimitry Andric OtherIndex = 0; 148*04eeddc0SDimitry Andric } 149*04eeddc0SDimitry Andric if (!Splat) 150*04eeddc0SDimitry Andric return std::make_pair(nullptr, nullptr); 151*04eeddc0SDimitry Andric 152*04eeddc0SDimitry Andric Value *Stride; 153*04eeddc0SDimitry Andric std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex), 154*04eeddc0SDimitry Andric Builder); 155*04eeddc0SDimitry Andric if (!Start) 156*04eeddc0SDimitry Andric return std::make_pair(nullptr, nullptr); 157*04eeddc0SDimitry Andric 158*04eeddc0SDimitry Andric // Add the splat value to the start. 159*04eeddc0SDimitry Andric Builder.SetInsertPoint(BO); 160*04eeddc0SDimitry Andric Builder.SetCurrentDebugLocation(DebugLoc()); 161*04eeddc0SDimitry Andric Start = Builder.CreateAdd(Start, Splat); 162*04eeddc0SDimitry Andric return std::make_pair(Start, Stride); 163*04eeddc0SDimitry Andric } 164*04eeddc0SDimitry Andric 165349cc55cSDimitry Andric // Recursively, walk about the use-def chain until we find a Phi with a strided 166349cc55cSDimitry Andric // start value. Build and update a scalar recurrence as we unwind the recursion. 167349cc55cSDimitry Andric // We also update the Stride as we unwind. Our goal is to move all of the 168349cc55cSDimitry Andric // arithmetic out of the loop. 169349cc55cSDimitry Andric bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L, 170349cc55cSDimitry Andric Value *&Stride, 171349cc55cSDimitry Andric PHINode *&BasePtr, 172349cc55cSDimitry Andric BinaryOperator *&Inc, 173349cc55cSDimitry Andric IRBuilder<> &Builder) { 174349cc55cSDimitry Andric // Our base case is a Phi. 175349cc55cSDimitry Andric if (auto *Phi = dyn_cast<PHINode>(Index)) { 176349cc55cSDimitry Andric // A phi node we want to perform this function on should be from the 177349cc55cSDimitry Andric // loop header. 178349cc55cSDimitry Andric if (Phi->getParent() != L->getHeader()) 179349cc55cSDimitry Andric return false; 180349cc55cSDimitry Andric 181349cc55cSDimitry Andric Value *Step, *Start; 182349cc55cSDimitry Andric if (!matchSimpleRecurrence(Phi, Inc, Start, Step) || 183349cc55cSDimitry Andric Inc->getOpcode() != Instruction::Add) 184349cc55cSDimitry Andric return false; 185349cc55cSDimitry Andric assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); 186349cc55cSDimitry Andric unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1; 187349cc55cSDimitry Andric assert(Phi->getIncomingValue(IncrementingBlock) == Inc && 188349cc55cSDimitry Andric "Expected one operand of phi to be Inc"); 189349cc55cSDimitry Andric 190349cc55cSDimitry Andric // Only proceed if the step is loop invariant. 191349cc55cSDimitry Andric if (!L->isLoopInvariant(Step)) 192349cc55cSDimitry Andric return false; 193349cc55cSDimitry Andric 194349cc55cSDimitry Andric // Step should be a splat. 195349cc55cSDimitry Andric Step = getSplatValue(Step); 196349cc55cSDimitry Andric if (!Step) 197349cc55cSDimitry Andric return false; 198349cc55cSDimitry Andric 199*04eeddc0SDimitry Andric std::tie(Start, Stride) = matchStridedStart(Start, Builder); 200349cc55cSDimitry Andric if (!Start) 201349cc55cSDimitry Andric return false; 202349cc55cSDimitry Andric assert(Stride != nullptr); 203349cc55cSDimitry Andric 204349cc55cSDimitry Andric // Build scalar phi and increment. 205349cc55cSDimitry Andric BasePtr = 206349cc55cSDimitry Andric PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi); 207349cc55cSDimitry Andric Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar", 208349cc55cSDimitry Andric Inc); 209349cc55cSDimitry Andric BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock)); 210349cc55cSDimitry Andric BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock)); 211349cc55cSDimitry Andric 212349cc55cSDimitry Andric // Note that this Phi might be eligible for removal. 213349cc55cSDimitry Andric MaybeDeadPHIs.push_back(Phi); 214349cc55cSDimitry Andric return true; 215349cc55cSDimitry Andric } 216349cc55cSDimitry Andric 217349cc55cSDimitry Andric // Otherwise look for binary operator. 218349cc55cSDimitry Andric auto *BO = dyn_cast<BinaryOperator>(Index); 219349cc55cSDimitry Andric if (!BO) 220349cc55cSDimitry Andric return false; 221349cc55cSDimitry Andric 222349cc55cSDimitry Andric if (BO->getOpcode() != Instruction::Add && 223349cc55cSDimitry Andric BO->getOpcode() != Instruction::Or && 224349cc55cSDimitry Andric BO->getOpcode() != Instruction::Mul && 225349cc55cSDimitry Andric BO->getOpcode() != Instruction::Shl) 226349cc55cSDimitry Andric return false; 227349cc55cSDimitry Andric 228349cc55cSDimitry Andric // Only support shift by constant. 229349cc55cSDimitry Andric if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1))) 230349cc55cSDimitry Andric return false; 231349cc55cSDimitry Andric 232349cc55cSDimitry Andric // We need to be able to treat Or as Add. 233349cc55cSDimitry Andric if (BO->getOpcode() == Instruction::Or && 234349cc55cSDimitry Andric !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL)) 235349cc55cSDimitry Andric return false; 236349cc55cSDimitry Andric 237349cc55cSDimitry Andric // We should have one operand in the loop and one splat. 238349cc55cSDimitry Andric Value *OtherOp; 239349cc55cSDimitry Andric if (isa<Instruction>(BO->getOperand(0)) && 240349cc55cSDimitry Andric L->contains(cast<Instruction>(BO->getOperand(0)))) { 241349cc55cSDimitry Andric Index = cast<Instruction>(BO->getOperand(0)); 242349cc55cSDimitry Andric OtherOp = BO->getOperand(1); 243349cc55cSDimitry Andric } else if (isa<Instruction>(BO->getOperand(1)) && 244349cc55cSDimitry Andric L->contains(cast<Instruction>(BO->getOperand(1)))) { 245349cc55cSDimitry Andric Index = cast<Instruction>(BO->getOperand(1)); 246349cc55cSDimitry Andric OtherOp = BO->getOperand(0); 247349cc55cSDimitry Andric } else { 248349cc55cSDimitry Andric return false; 249349cc55cSDimitry Andric } 250349cc55cSDimitry Andric 251349cc55cSDimitry Andric // Make sure other op is loop invariant. 252349cc55cSDimitry Andric if (!L->isLoopInvariant(OtherOp)) 253349cc55cSDimitry Andric return false; 254349cc55cSDimitry Andric 255349cc55cSDimitry Andric // Make sure we have a splat. 256349cc55cSDimitry Andric Value *SplatOp = getSplatValue(OtherOp); 257349cc55cSDimitry Andric if (!SplatOp) 258349cc55cSDimitry Andric return false; 259349cc55cSDimitry Andric 260349cc55cSDimitry Andric // Recurse up the use-def chain. 261349cc55cSDimitry Andric if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder)) 262349cc55cSDimitry Andric return false; 263349cc55cSDimitry Andric 264349cc55cSDimitry Andric // Locate the Step and Start values from the recurrence. 265349cc55cSDimitry Andric unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0; 266349cc55cSDimitry Andric unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0; 267349cc55cSDimitry Andric Value *Step = Inc->getOperand(StepIndex); 268349cc55cSDimitry Andric Value *Start = BasePtr->getOperand(StartBlock); 269349cc55cSDimitry Andric 270349cc55cSDimitry Andric // We need to adjust the start value in the preheader. 271349cc55cSDimitry Andric Builder.SetInsertPoint( 272349cc55cSDimitry Andric BasePtr->getIncomingBlock(StartBlock)->getTerminator()); 273349cc55cSDimitry Andric Builder.SetCurrentDebugLocation(DebugLoc()); 274349cc55cSDimitry Andric 275349cc55cSDimitry Andric switch (BO->getOpcode()) { 276349cc55cSDimitry Andric default: 277349cc55cSDimitry Andric llvm_unreachable("Unexpected opcode!"); 278349cc55cSDimitry Andric case Instruction::Add: 279349cc55cSDimitry Andric case Instruction::Or: { 280349cc55cSDimitry Andric // An add only affects the start value. It's ok to do this for Or because 281349cc55cSDimitry Andric // we already checked that there are no common set bits. 282349cc55cSDimitry Andric 283349cc55cSDimitry Andric // If the start value is Zero, just take the SplatOp. 284349cc55cSDimitry Andric if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero()) 285349cc55cSDimitry Andric Start = SplatOp; 286349cc55cSDimitry Andric else 287349cc55cSDimitry Andric Start = Builder.CreateAdd(Start, SplatOp, "start"); 288349cc55cSDimitry Andric BasePtr->setIncomingValue(StartBlock, Start); 289349cc55cSDimitry Andric break; 290349cc55cSDimitry Andric } 291349cc55cSDimitry Andric case Instruction::Mul: { 292349cc55cSDimitry Andric // If the start is zero we don't need to multiply. 293349cc55cSDimitry Andric if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero()) 294349cc55cSDimitry Andric Start = Builder.CreateMul(Start, SplatOp, "start"); 295349cc55cSDimitry Andric 296349cc55cSDimitry Andric Step = Builder.CreateMul(Step, SplatOp, "step"); 297349cc55cSDimitry Andric 298349cc55cSDimitry Andric // If the Stride is 1 just take the SplatOpt. 299349cc55cSDimitry Andric if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne()) 300349cc55cSDimitry Andric Stride = SplatOp; 301349cc55cSDimitry Andric else 302349cc55cSDimitry Andric Stride = Builder.CreateMul(Stride, SplatOp, "stride"); 303349cc55cSDimitry Andric Inc->setOperand(StepIndex, Step); 304349cc55cSDimitry Andric BasePtr->setIncomingValue(StartBlock, Start); 305349cc55cSDimitry Andric break; 306349cc55cSDimitry Andric } 307349cc55cSDimitry Andric case Instruction::Shl: { 308349cc55cSDimitry Andric // If the start is zero we don't need to shift. 309349cc55cSDimitry Andric if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero()) 310349cc55cSDimitry Andric Start = Builder.CreateShl(Start, SplatOp, "start"); 311349cc55cSDimitry Andric Step = Builder.CreateShl(Step, SplatOp, "step"); 312349cc55cSDimitry Andric Stride = Builder.CreateShl(Stride, SplatOp, "stride"); 313349cc55cSDimitry Andric Inc->setOperand(StepIndex, Step); 314349cc55cSDimitry Andric BasePtr->setIncomingValue(StartBlock, Start); 315349cc55cSDimitry Andric break; 316349cc55cSDimitry Andric } 317349cc55cSDimitry Andric } 318349cc55cSDimitry Andric 319349cc55cSDimitry Andric return true; 320349cc55cSDimitry Andric } 321349cc55cSDimitry Andric 322349cc55cSDimitry Andric std::pair<Value *, Value *> 323349cc55cSDimitry Andric RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP, 324349cc55cSDimitry Andric IRBuilder<> &Builder) { 325349cc55cSDimitry Andric 326349cc55cSDimitry Andric SmallVector<Value *, 2> Ops(GEP->operands()); 327349cc55cSDimitry Andric 328349cc55cSDimitry Andric // Base pointer needs to be a scalar. 329349cc55cSDimitry Andric if (Ops[0]->getType()->isVectorTy()) 330349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 331349cc55cSDimitry Andric 332349cc55cSDimitry Andric // Make sure we're in a loop and it is in loop simplify form. 333349cc55cSDimitry Andric Loop *L = LI->getLoopFor(GEP->getParent()); 334349cc55cSDimitry Andric if (!L || !L->isLoopSimplifyForm()) 335349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 336349cc55cSDimitry Andric 337349cc55cSDimitry Andric Optional<unsigned> VecOperand; 338349cc55cSDimitry Andric unsigned TypeScale = 0; 339349cc55cSDimitry Andric 340349cc55cSDimitry Andric // Look for a vector operand and scale. 341349cc55cSDimitry Andric gep_type_iterator GTI = gep_type_begin(GEP); 342349cc55cSDimitry Andric for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { 343349cc55cSDimitry Andric if (!Ops[i]->getType()->isVectorTy()) 344349cc55cSDimitry Andric continue; 345349cc55cSDimitry Andric 346349cc55cSDimitry Andric if (VecOperand) 347349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 348349cc55cSDimitry Andric 349349cc55cSDimitry Andric VecOperand = i; 350349cc55cSDimitry Andric 351349cc55cSDimitry Andric TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType()); 352349cc55cSDimitry Andric if (TS.isScalable()) 353349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 354349cc55cSDimitry Andric 355349cc55cSDimitry Andric TypeScale = TS.getFixedSize(); 356349cc55cSDimitry Andric } 357349cc55cSDimitry Andric 358349cc55cSDimitry Andric // We need to find a vector index to simplify. 359349cc55cSDimitry Andric if (!VecOperand) 360349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 361349cc55cSDimitry Andric 362349cc55cSDimitry Andric // We can't extract the stride if the arithmetic is done at a different size 363349cc55cSDimitry Andric // than the pointer type. Adding the stride later may not wrap correctly. 364349cc55cSDimitry Andric // Technically we could handle wider indices, but I don't expect that in 365349cc55cSDimitry Andric // practice. 366349cc55cSDimitry Andric Value *VecIndex = Ops[*VecOperand]; 367349cc55cSDimitry Andric Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType()); 368349cc55cSDimitry Andric if (VecIndex->getType() != VecIntPtrTy) 369349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 370349cc55cSDimitry Andric 371349cc55cSDimitry Andric Value *Stride; 372349cc55cSDimitry Andric BinaryOperator *Inc; 373349cc55cSDimitry Andric PHINode *BasePhi; 374349cc55cSDimitry Andric if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder)) 375349cc55cSDimitry Andric return std::make_pair(nullptr, nullptr); 376349cc55cSDimitry Andric 377349cc55cSDimitry Andric assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); 378349cc55cSDimitry Andric unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1; 379349cc55cSDimitry Andric assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc && 380349cc55cSDimitry Andric "Expected one operand of phi to be Inc"); 381349cc55cSDimitry Andric 382349cc55cSDimitry Andric Builder.SetInsertPoint(GEP); 383349cc55cSDimitry Andric 384349cc55cSDimitry Andric // Replace the vector index with the scalar phi and build a scalar GEP. 385349cc55cSDimitry Andric Ops[*VecOperand] = BasePhi; 386349cc55cSDimitry Andric Type *SourceTy = GEP->getSourceElementType(); 387349cc55cSDimitry Andric Value *BasePtr = 388349cc55cSDimitry Andric Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front()); 389349cc55cSDimitry Andric 390349cc55cSDimitry Andric // Cast the GEP to an i8*. 391349cc55cSDimitry Andric LLVMContext &Ctx = GEP->getContext(); 392349cc55cSDimitry Andric Type *I8PtrTy = 393349cc55cSDimitry Andric Type::getInt8PtrTy(Ctx, GEP->getType()->getPointerAddressSpace()); 394349cc55cSDimitry Andric if (BasePtr->getType() != I8PtrTy) 395349cc55cSDimitry Andric BasePtr = Builder.CreatePointerCast(BasePtr, I8PtrTy); 396349cc55cSDimitry Andric 397349cc55cSDimitry Andric // Final adjustments to stride should go in the start block. 398349cc55cSDimitry Andric Builder.SetInsertPoint( 399349cc55cSDimitry Andric BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator()); 400349cc55cSDimitry Andric 401349cc55cSDimitry Andric // Convert stride to pointer size if needed. 402349cc55cSDimitry Andric Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); 403349cc55cSDimitry Andric assert(Stride->getType() == IntPtrTy && "Unexpected type"); 404349cc55cSDimitry Andric 405349cc55cSDimitry Andric // Scale the stride by the size of the indexed type. 406349cc55cSDimitry Andric if (TypeScale != 1) 407349cc55cSDimitry Andric Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale)); 408349cc55cSDimitry Andric 409349cc55cSDimitry Andric return std::make_pair(BasePtr, Stride); 410349cc55cSDimitry Andric } 411349cc55cSDimitry Andric 412349cc55cSDimitry Andric bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II, 413349cc55cSDimitry Andric Type *DataType, 414349cc55cSDimitry Andric Value *Ptr, 415349cc55cSDimitry Andric Value *AlignOp) { 416349cc55cSDimitry Andric // Make sure the operation will be supported by the backend. 417349cc55cSDimitry Andric if (!isLegalTypeAndAlignment(DataType, AlignOp)) 418349cc55cSDimitry Andric return false; 419349cc55cSDimitry Andric 420349cc55cSDimitry Andric // Pointer should be a GEP. 421349cc55cSDimitry Andric auto *GEP = dyn_cast<GetElementPtrInst>(Ptr); 422349cc55cSDimitry Andric if (!GEP) 423349cc55cSDimitry Andric return false; 424349cc55cSDimitry Andric 425349cc55cSDimitry Andric IRBuilder<> Builder(GEP); 426349cc55cSDimitry Andric 427349cc55cSDimitry Andric Value *BasePtr, *Stride; 428349cc55cSDimitry Andric std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder); 429349cc55cSDimitry Andric if (!BasePtr) 430349cc55cSDimitry Andric return false; 431349cc55cSDimitry Andric assert(Stride != nullptr); 432349cc55cSDimitry Andric 433349cc55cSDimitry Andric Builder.SetInsertPoint(II); 434349cc55cSDimitry Andric 435349cc55cSDimitry Andric CallInst *Call; 436349cc55cSDimitry Andric if (II->getIntrinsicID() == Intrinsic::masked_gather) 437349cc55cSDimitry Andric Call = Builder.CreateIntrinsic( 438349cc55cSDimitry Andric Intrinsic::riscv_masked_strided_load, 439349cc55cSDimitry Andric {DataType, BasePtr->getType(), Stride->getType()}, 440349cc55cSDimitry Andric {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)}); 441349cc55cSDimitry Andric else 442349cc55cSDimitry Andric Call = Builder.CreateIntrinsic( 443349cc55cSDimitry Andric Intrinsic::riscv_masked_strided_store, 444349cc55cSDimitry Andric {DataType, BasePtr->getType(), Stride->getType()}, 445349cc55cSDimitry Andric {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)}); 446349cc55cSDimitry Andric 447349cc55cSDimitry Andric Call->takeName(II); 448349cc55cSDimitry Andric II->replaceAllUsesWith(Call); 449349cc55cSDimitry Andric II->eraseFromParent(); 450349cc55cSDimitry Andric 451349cc55cSDimitry Andric if (GEP->use_empty()) 452349cc55cSDimitry Andric RecursivelyDeleteTriviallyDeadInstructions(GEP); 453349cc55cSDimitry Andric 454349cc55cSDimitry Andric return true; 455349cc55cSDimitry Andric } 456349cc55cSDimitry Andric 457349cc55cSDimitry Andric bool RISCVGatherScatterLowering::runOnFunction(Function &F) { 458349cc55cSDimitry Andric if (skipFunction(F)) 459349cc55cSDimitry Andric return false; 460349cc55cSDimitry Andric 461349cc55cSDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 462349cc55cSDimitry Andric auto &TM = TPC.getTM<RISCVTargetMachine>(); 463349cc55cSDimitry Andric ST = &TM.getSubtarget<RISCVSubtarget>(F); 464349cc55cSDimitry Andric if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors()) 465349cc55cSDimitry Andric return false; 466349cc55cSDimitry Andric 467349cc55cSDimitry Andric TLI = ST->getTargetLowering(); 468349cc55cSDimitry Andric DL = &F.getParent()->getDataLayout(); 469349cc55cSDimitry Andric LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 470349cc55cSDimitry Andric 471349cc55cSDimitry Andric SmallVector<IntrinsicInst *, 4> Gathers; 472349cc55cSDimitry Andric SmallVector<IntrinsicInst *, 4> Scatters; 473349cc55cSDimitry Andric 474349cc55cSDimitry Andric bool Changed = false; 475349cc55cSDimitry Andric 476349cc55cSDimitry Andric for (BasicBlock &BB : F) { 477349cc55cSDimitry Andric for (Instruction &I : BB) { 478349cc55cSDimitry Andric IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 479349cc55cSDimitry Andric if (II && II->getIntrinsicID() == Intrinsic::masked_gather && 480349cc55cSDimitry Andric isa<FixedVectorType>(II->getType())) { 481349cc55cSDimitry Andric Gathers.push_back(II); 482349cc55cSDimitry Andric } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter && 483349cc55cSDimitry Andric isa<FixedVectorType>(II->getArgOperand(0)->getType())) { 484349cc55cSDimitry Andric Scatters.push_back(II); 485349cc55cSDimitry Andric } 486349cc55cSDimitry Andric } 487349cc55cSDimitry Andric } 488349cc55cSDimitry Andric 489349cc55cSDimitry Andric // Rewrite gather/scatter to form strided load/store if possible. 490349cc55cSDimitry Andric for (auto *II : Gathers) 491349cc55cSDimitry Andric Changed |= tryCreateStridedLoadStore( 492349cc55cSDimitry Andric II, II->getType(), II->getArgOperand(0), II->getArgOperand(1)); 493349cc55cSDimitry Andric for (auto *II : Scatters) 494349cc55cSDimitry Andric Changed |= 495349cc55cSDimitry Andric tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(), 496349cc55cSDimitry Andric II->getArgOperand(1), II->getArgOperand(2)); 497349cc55cSDimitry Andric 498349cc55cSDimitry Andric // Remove any dead phis. 499349cc55cSDimitry Andric while (!MaybeDeadPHIs.empty()) { 500349cc55cSDimitry Andric if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val())) 501349cc55cSDimitry Andric RecursivelyDeleteDeadPHINode(Phi); 502349cc55cSDimitry Andric } 503349cc55cSDimitry Andric 504349cc55cSDimitry Andric return Changed; 505349cc55cSDimitry Andric } 506