//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.gather and llvm.scatter instructions to
// RISCV intrinsics.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

// Function pass that rewrites fixed-vector llvm.masked.gather/scatter calls
// whose pointer operand is a GEP with a strided index recurrence into the
// riscv.masked.strided.load/store intrinsics (scalar base pointer + scalar
// byte stride).
class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  // Vector phis we scalarized; candidates for deletion once their vector
  // users are rewritten. WeakTrackingVH so entries null out if deleted
  // earlier (e.g. by RecursivelyDeleteTriviallyDeadInstructions).
  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISCV gather/scatter lowering";
  }

private:
  bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);

  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
                                                     IRBuilder<> &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilder<> &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISCV gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

// Returns true if the backend can handle a strided access with this element
// type, alignment, and overall vector type. Rejects under-aligned accesses
// (alignment below the element's store size) and types RVV can't hold legally.
bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
                                                         Value *AlignOp) {
  Type *ScalarType = DataType->getScalarType();
  if (!TLI->isLegalElementTypeForRVV(ScalarType))
    return false;

  // The align operand of masked.gather/scatter is a constant; require it to
  // be at least the element's natural store size.
  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize())
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  EVT DataVT = TLI->getValueType(*DL, DataType);
  if (!TLI->isTypeLegal(DataVT))
    return false;

  return true;
}

// TODO: Should we consider the mask when looking for a stride?
//
// If StartC is a constant vector of the form <a, a+s, a+2s, ...> (all lanes
// ConstantInt with a common difference), return {start scalar, stride scalar}.
// Otherwise return {nullptr, nullptr}.
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    // The stride is fixed by the first pair of lanes; every later pair must
    // match it.
    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

// Recursively, walk about the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the recursion.
// We also update the Stride as we unwind. Our goal is to move all of the
// arithmetic out of the loop.
//
// On success: Stride is a loop-invariant scalar stride (in index units),
// BasePtr is a newly-built scalar phi mirroring the vector recurrence, and
// Inc is its scalar add increment. This mutates the IR even when a caller
// later rejects the overall transform; the new scalar recurrence is simply
// left dead in that case.
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilder<> &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    // Work out which incoming edge carries the increment; the other edge
    // carries the start value from outside the loop.
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    // Start should be a strided constant.
    auto *StartC = dyn_cast<Constant>(Start);
    if (!StartC)
      return false;

    std::tie(Start, Stride) = matchStridedConstant(StartC);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr =
        PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc);
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  if (BO->getOpcode() != Instruction::Add &&
      BO->getOpcode() != Instruction::Or &&
      BO->getOpcode() != Instruction::Mul &&
      BO->getOpcode() != Instruction::Shl)
    return false;

  // Only support shift by constant.
  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
    return false;

  // We need to be able to treat Or as Add.
  if (BO->getOpcode() == Instruction::Or &&
      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
    return false;

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1)))) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence built by the
  // recursive call so we can fold this BO's effect into them.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  // Hoisted arithmetic shouldn't inherit the loop body's debug location.
  Builder.SetCurrentDebugLocation(DebugLoc());

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.

    // If the start value is Zero, just take the SplatOp.
    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
      Start = SplatOp;
    else
      Start = Builder.CreateAdd(Start, SplatOp, "start");
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Mul: {
    // A multiply scales the start, the per-iteration step, and the
    // lane-to-lane stride.

    // If the start is zero we don't need to multiply.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateMul(Start, SplatOp, "start");

    Step = Builder.CreateMul(Step, SplatOp, "step");

    // If the Stride is 1 just take the SplatOpt.
    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
      Stride = SplatOp;
    else
      Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Shl: {
    // A shift likewise scales start, step, and stride (by a power of two).

    // If the start is zero we don't need to shift.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  }

  return true;
}

// Given the GEP feeding a gather/scatter, try to produce an equivalent
// {scalar i8* base pointer, scalar byte stride} pair. Returns
// {nullptr, nullptr} if the GEP's vector index isn't a recognizable strided
// recurrence. On success the IR has been rewritten: a scalar GEP replaces the
// vector index and loop-invariant stride math is placed in the preheader.
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
                                                   IRBuilder<> &Builder) {

  SmallVector<Value *, 2> Ops(GEP->operands());

  // Base pointer needs to be a scalar.
  if (Ops[0]->getType()->isVectorTy())
    return std::make_pair(nullptr, nullptr);

  // Make sure we're in a loop and it is in loop simplify form.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->isLoopSimplifyForm())
    return std::make_pair(nullptr, nullptr);

  Optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale. Exactly one GEP index may be a
  // vector; its indexed type's size becomes the byte scale for the stride.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedSize();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());

  // Cast the GEP to an i8*.
  LLVMContext &Ctx = GEP->getContext();
  Type *I8PtrTy =
      Type::getInt8PtrTy(Ctx, GEP->getType()->getPointerAddressSpace());
  if (BasePtr->getType() != I8PtrTy)
    BasePtr = Builder.CreatePointerCast(BasePtr, I8PtrTy);

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type so it is expressed in
  // bytes, matching the i8* base pointer.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  return std::make_pair(BasePtr, Stride);
}

// Rewrite a single masked.gather/masked.scatter call II (with data type
// DataType, pointer operand Ptr, and alignment operand AlignOp) into a
// riscv.masked.strided.load/store. Returns true if the IR was changed.
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  if (!isLegalTypeAndAlignment(DataType, AlignOp))
    return false;

  // Pointer should be a GEP.
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return false;

  IRBuilder<> Builder(GEP);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    // gather args: (ptrs, align, mask, passthru) -> strided.load args:
    // (passthru, base, stride, mask).
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    // scatter args: (value, ptrs, align, mask) -> strided.store args:
    // (value, base, stride, mask).
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  // The old vector GEP may now be dead; clean it (and its operands) up.
  if (GEP->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(GEP);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  // This lowering only applies when RVV is available and fixed-length vectors
  // are being lowered to it.
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getParent()->getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  // Collect candidates first; the rewrite erases instructions, so don't
  // mutate while iterating over the blocks. Only fixed vector types are
  // handled (scalable types were rejected above via the subtarget checks'
  // intent; here the type check enforces it per call).
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(
        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |=
        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                  II->getArgOperand(1), II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}