1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to 10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to 11 /// produce a better final result as we go. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "ARM.h" 16 #include "ARMBaseInstrInfo.h" 17 #include "ARMSubtarget.h" 18 #include "llvm/Analysis/LoopInfo.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/Analysis/ValueTracking.h" 21 #include "llvm/CodeGen/TargetLowering.h" 22 #include "llvm/CodeGen/TargetPassConfig.h" 23 #include "llvm/CodeGen/TargetSubtargetInfo.h" 24 #include "llvm/IR/BasicBlock.h" 25 #include "llvm/IR/Constant.h" 26 #include "llvm/IR/Constants.h" 27 #include "llvm/IR/DerivedTypes.h" 28 #include "llvm/IR/Function.h" 29 #include "llvm/IR/IRBuilder.h" 30 #include "llvm/IR/InstrTypes.h" 31 #include "llvm/IR/Instruction.h" 32 #include "llvm/IR/Instructions.h" 33 #include "llvm/IR/IntrinsicInst.h" 34 #include "llvm/IR/Intrinsics.h" 35 #include "llvm/IR/IntrinsicsARM.h" 36 #include "llvm/IR/PatternMatch.h" 37 #include "llvm/IR/Type.h" 38 #include "llvm/IR/Value.h" 39 #include "llvm/InitializePasses.h" 40 #include "llvm/Pass.h" 41 #include "llvm/Support/Casting.h" 42 #include "llvm/Transforms/Utils/Local.h" 43 #include <cassert> 44 45 using namespace llvm; 46 47 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering" 48 49 cl::opt<bool> EnableMaskedGatherScatters( 50 "enable-arm-maskedgatscat", cl::Hidden, cl::init(true), 51 cl::desc("Enable the generation of masked gathers and scatters")); 52 53 namespace { 54 55 class MVEGatherScatterLowering : public FunctionPass { 56 public: 57 static char ID; // Pass identification, replacement for typeid 58 59 explicit MVEGatherScatterLowering() : FunctionPass(ID) { 60 initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); 61 } 62 63 bool runOnFunction(Function &F) override; 64 65 StringRef getPassName() const override { 66 return "MVE gather/scatter lowering"; 67 } 68 69 void getAnalysisUsage(AnalysisUsage &AU) const override { 70 AU.setPreservesCFG(); 71 AU.addRequired<TargetPassConfig>(); 72 AU.addRequired<LoopInfoWrapperPass>(); 73 FunctionPass::getAnalysisUsage(AU); 74 } 75 76 private: 77 LoopInfo *LI = nullptr; 78 const DataLayout *DL; 79 80 // Check this is a valid gather with correct alignment 81 bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, 82 Align Alignment); 83 // Check whether Ptr is hidden behind a bitcast and look through it 84 void lookThroughBitcast(Value *&Ptr); 85 // Decompose a ptr into Base and Offsets, potentially using a GEP to return a 86 // scalar base and vector offsets, or else fallback to using a base of 0 and 87 // offset of Ptr where possible. 88 Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale, 89 FixedVectorType *Ty, Type *MemoryTy, 90 IRBuilder<> &Builder); 91 // Check for a getelementptr and deduce base and offsets from it, on success 92 // returning the base directly and the offsets indirectly using the Offsets 93 // argument 94 Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty, 95 GetElementPtrInst *GEP, IRBuilder<> &Builder); 96 // Compute the scale of this gather/scatter instruction 97 int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize); 98 // If the value is a constant, or derived from constants via additions 99 // and multilications, return its numeric value 100 std::optional<int64_t> getIfConst(const Value *V); 101 // If Inst is an add instruction, check whether one summand is a 102 // constant. If so, scale this constant and return it together with 103 // the other summand. 104 std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale); 105 106 Instruction *lowerGather(IntrinsicInst *I); 107 // Create a gather from a base + vector of offsets 108 Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, 109 Instruction *&Root, 110 IRBuilder<> &Builder); 111 // Create a gather from a vector of pointers 112 Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, 113 IRBuilder<> &Builder, 114 int64_t Increment = 0); 115 // Create an incrementing gather from a vector of pointers 116 Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr, 117 IRBuilder<> &Builder, 118 int64_t Increment = 0); 119 120 Instruction *lowerScatter(IntrinsicInst *I); 121 // Create a scatter to a base + vector of offsets 122 Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets, 123 IRBuilder<> &Builder); 124 // Create a scatter to a vector of pointers 125 Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr, 126 IRBuilder<> &Builder, 127 int64_t Increment = 0); 128 // Create an incrementing scatter from a vector of pointers 129 Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr, 130 IRBuilder<> &Builder, 131 int64_t Increment = 0); 132 133 // QI gathers and scatters can increment their offsets on their own if 134 // the increment is a constant value (digit) 135 Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr, 136 IRBuilder<> &Builder); 137 // QI gathers/scatters can increment their offsets on their own if the 138 // increment is a constant value (digit) - this creates a writeback QI 139 // gather/scatter 140 Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr, 141 Value *Ptr, unsigned TypeScale, 142 IRBuilder<> &Builder); 143 144 // Optimise the base and offsets of the given address 145 bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI); 146 // Try to fold consecutive geps together into one 147 Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale, 148 IRBuilder<> &Builder); 149 // Check whether these offsets could be moved out of the loop they're in 150 bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); 151 // Pushes the given add out of the loop 152 void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); 153 // Pushes the given mul or shl out of the loop 154 void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound, 155 Value *OffsSecondOperand, unsigned LoopIncrement, 156 IRBuilder<> &Builder); 157 }; 158 159 } // end anonymous namespace 160 161 char MVEGatherScatterLowering::ID = 0; 162 163 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, 164 "MVE gather/scattering lowering pass", false, false) 165 166 Pass *llvm::createMVEGatherScatterLoweringPass() { 167 return new MVEGatherScatterLowering(); 168 } 169 170 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, 171 unsigned ElemSize, 172 Align Alignment) { 173 if (((NumElements == 4 && 174 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) || 175 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) || 176 (NumElements == 16 && ElemSize == 8)) && 177 Alignment >= ElemSize / 8) 178 return true; 179 LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have " 180 << "valid alignment or vector type \n"); 181 return false; 182 } 183 184 static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) { 185 // Offsets that are not of type <N x i32> are sign extended by the 186 // getelementptr instruction, and MVE gathers/scatters treat the offset as 187 // unsigned. Thus, if the element size is smaller than 32, we can only allow 188 // positive offsets - i.e., the offsets are not allowed to be variables we 189 // can't look into. 190 // Additionally, <N x i32> offsets have to either originate from a zext of a 191 // vector with element types smaller or equal the type of the gather we're 192 // looking at, or consist of constants that we can check are small enough 193 // to fit into the gather type. 194 // Thus we check that 0 < value < 2^TargetElemSize. 195 unsigned TargetElemSize = 128 / TargetElemCount; 196 unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType()) 197 ->getElementType() 198 ->getScalarSizeInBits(); 199 if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) { 200 Constant *ConstOff = dyn_cast<Constant>(Offsets); 201 if (!ConstOff) 202 return false; 203 int64_t TargetElemMaxSize = (1ULL << TargetElemSize); 204 auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) { 205 ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem); 206 if (!OConst) 207 return false; 208 int SExtValue = OConst->getSExtValue(); 209 if (SExtValue >= TargetElemMaxSize || SExtValue < 0) 210 return false; 211 return true; 212 }; 213 if (isa<FixedVectorType>(ConstOff->getType())) { 214 for (unsigned i = 0; i < TargetElemCount; i++) { 215 if (!CheckValueSize(ConstOff->getAggregateElement(i))) 216 return false; 217 } 218 } else { 219 if (!CheckValueSize(ConstOff)) 220 return false; 221 } 222 } 223 return true; 224 } 225 226 Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets, 227 int &Scale, FixedVectorType *Ty, 228 Type *MemoryTy, 229 IRBuilder<> &Builder) { 230 if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) { 231 if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) { 232 Scale = 233 computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(), 234 MemoryTy->getScalarSizeInBits()); 235 return Scale == -1 ? nullptr : V; 236 } 237 } 238 239 // If we couldn't use the GEP (or it doesn't exist), attempt to use a 240 // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4 241 // elements. 242 FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType()); 243 if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32) 244 return nullptr; 245 Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0); 246 Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getPtrTy()); 247 Offsets = Builder.CreatePtrToInt( 248 Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4)); 249 Scale = 0; 250 return BasePtr; 251 } 252 253 Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets, 254 FixedVectorType *Ty, 255 GetElementPtrInst *GEP, 256 IRBuilder<> &Builder) { 257 if (!GEP) { 258 LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer " 259 << "found\n"); 260 return nullptr; 261 } 262 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found." 263 << " Looking at intrinsic for base + vector of offsets\n"); 264 Value *GEPPtr = GEP->getPointerOperand(); 265 Offsets = GEP->getOperand(1); 266 if (GEPPtr->getType()->isVectorTy() || 267 !isa<FixedVectorType>(Offsets->getType())) 268 return nullptr; 269 270 if (GEP->getNumOperands() != 2) { 271 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many" 272 << " operands. Expanding.\n"); 273 return nullptr; 274 } 275 Offsets = GEP->getOperand(1); 276 unsigned OffsetsElemCount = 277 cast<FixedVectorType>(Offsets->getType())->getNumElements(); 278 // Paranoid check whether the number of parallel lanes is the same 279 assert(Ty->getNumElements() == OffsetsElemCount); 280 281 ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets); 282 if (ZextOffs) 283 Offsets = ZextOffs->getOperand(0); 284 FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType()); 285 286 // If the offsets are already being zext-ed to <N x i32>, that relieves us of 287 // having to make sure that they won't overflow. 288 if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy()) 289 ->getElementType() 290 ->getScalarSizeInBits() != 32) 291 if (!checkOffsetSize(Offsets, OffsetsElemCount)) 292 return nullptr; 293 294 // The offset sizes have been checked; if any truncating or zext-ing is 295 // required to fix them, do that now 296 if (Ty != Offsets->getType()) { 297 if ((Ty->getElementType()->getScalarSizeInBits() < 298 OffsetType->getElementType()->getScalarSizeInBits())) { 299 Offsets = Builder.CreateTrunc(Offsets, Ty); 300 } else { 301 Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty)); 302 } 303 } 304 // If none of the checks failed, return the gep's base pointer 305 LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n"); 306 return GEPPtr; 307 } 308 309 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { 310 // Look through bitcast instruction if #elements is the same 311 if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { 312 auto *BCTy = cast<FixedVectorType>(BitCast->getType()); 313 auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType()); 314 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) { 315 LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through " 316 << "bitcast\n"); 317 Ptr = BitCast->getOperand(0); 318 } 319 } 320 } 321 322 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize, 323 unsigned MemoryElemSize) { 324 // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2, 325 // or a 8bit, 16bit or 32bit load/store scaled by 1 326 if (GEPElemSize == 32 && MemoryElemSize == 32) 327 return 2; 328 else if (GEPElemSize == 16 && MemoryElemSize == 16) 329 return 1; 330 else if (GEPElemSize == 8) 331 return 0; 332 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't " 333 << "create intrinsic\n"); 334 return -1; 335 } 336 337 std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) { 338 const Constant *C = dyn_cast<Constant>(V); 339 if (C && C->getSplatValue()) 340 return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()}; 341 if (!isa<Instruction>(V)) 342 return std::optional<int64_t>{}; 343 344 const Instruction *I = cast<Instruction>(V); 345 if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or || 346 I->getOpcode() == Instruction::Mul || 347 I->getOpcode() == Instruction::Shl) { 348 std::optional<int64_t> Op0 = getIfConst(I->getOperand(0)); 349 std::optional<int64_t> Op1 = getIfConst(I->getOperand(1)); 350 if (!Op0 || !Op1) 351 return std::optional<int64_t>{}; 352 if (I->getOpcode() == Instruction::Add) 353 return std::optional<int64_t>{*Op0 + *Op1}; 354 if (I->getOpcode() == Instruction::Mul) 355 return std::optional<int64_t>{*Op0 * *Op1}; 356 if (I->getOpcode() == Instruction::Shl) 357 return std::optional<int64_t>{*Op0 << *Op1}; 358 if (I->getOpcode() == Instruction::Or) 359 return std::optional<int64_t>{*Op0 | *Op1}; 360 } 361 return std::optional<int64_t>{}; 362 } 363 364 // Return true if I is an Or instruction that is equivalent to an add, due to 365 // the operands having no common bits set. 366 static bool isAddLikeOr(Instruction *I, const DataLayout &DL) { 367 return I->getOpcode() == Instruction::Or && 368 haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL); 369 } 370 371 std::pair<Value *, int64_t> 372 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) { 373 std::pair<Value *, int64_t> ReturnFalse = 374 std::pair<Value *, int64_t>(nullptr, 0); 375 // At this point, the instruction we're looking at must be an add or an 376 // add-like-or. 377 Instruction *Add = dyn_cast<Instruction>(Inst); 378 if (Add == nullptr || 379 (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL))) 380 return ReturnFalse; 381 382 Value *Summand; 383 std::optional<int64_t> Const; 384 // Find out which operand the value that is increased is 385 if ((Const = getIfConst(Add->getOperand(0)))) 386 Summand = Add->getOperand(1); 387 else if ((Const = getIfConst(Add->getOperand(1)))) 388 Summand = Add->getOperand(0); 389 else 390 return ReturnFalse; 391 392 // Check that the constant is small enough for an incrementing gather 393 int64_t Immediate = *Const << TypeScale; 394 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0) 395 return ReturnFalse; 396 397 return std::pair<Value *, int64_t>(Summand, Immediate); 398 } 399 400 Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { 401 using namespace PatternMatch; 402 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n" 403 << *I << "\n"); 404 405 // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) 406 // Attempt to turn the masked gather in I into a MVE intrinsic 407 // Potentially optimising the addressing modes as we do so. 408 auto *Ty = cast<FixedVectorType>(I->getType()); 409 Value *Ptr = I->getArgOperand(0); 410 Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue(); 411 Value *Mask = I->getArgOperand(2); 412 Value *PassThru = I->getArgOperand(3); 413 414 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 415 Alignment)) 416 return nullptr; 417 lookThroughBitcast(Ptr); 418 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 419 420 IRBuilder<> Builder(I->getContext()); 421 Builder.SetInsertPoint(I); 422 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 423 424 Instruction *Root = I; 425 426 Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder); 427 if (!Load) 428 Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder); 429 if (!Load) 430 Load = tryCreateMaskedGatherBase(I, Ptr, Builder); 431 if (!Load) 432 return nullptr; 433 434 if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { 435 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " 436 << "creating select\n"); 437 Load = SelectInst::Create(Mask, Load, PassThru); 438 Builder.Insert(Load); 439 } 440 441 Root->replaceAllUsesWith(Load); 442 Root->eraseFromParent(); 443 if (Root != I) 444 // If this was an extending gather, we need to get rid of the sext/zext 445 // sext/zext as well as of the gather itself 446 I->eraseFromParent(); 447 448 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n" 449 << *Load << "\n"); 450 return Load; 451 } 452 453 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase( 454 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 455 using namespace PatternMatch; 456 auto *Ty = cast<FixedVectorType>(I->getType()); 457 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); 458 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 459 // Can't build an intrinsic for this 460 return nullptr; 461 Value *Mask = I->getArgOperand(2); 462 if (match(Mask, m_One())) 463 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, 464 {Ty, Ptr->getType()}, 465 {Ptr, Builder.getInt32(Increment)}); 466 else 467 return Builder.CreateIntrinsic( 468 Intrinsic::arm_mve_vldr_gather_base_predicated, 469 {Ty, Ptr->getType(), Mask->getType()}, 470 {Ptr, Builder.getInt32(Increment), Mask}); 471 } 472 473 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( 474 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 475 using namespace PatternMatch; 476 auto *Ty = cast<FixedVectorType>(I->getType()); 477 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with " 478 << "writeback\n"); 479 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 480 // Can't build an intrinsic for this 481 return nullptr; 482 Value *Mask = I->getArgOperand(2); 483 if (match(Mask, m_One())) 484 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, 485 {Ty, Ptr->getType()}, 486 {Ptr, Builder.getInt32(Increment)}); 487 else 488 return Builder.CreateIntrinsic( 489 Intrinsic::arm_mve_vldr_gather_base_wb_predicated, 490 {Ty, Ptr->getType(), Mask->getType()}, 491 {Ptr, Builder.getInt32(Increment), Mask}); 492 } 493 494 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( 495 IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) { 496 using namespace PatternMatch; 497 498 Type *MemoryTy = I->getType(); 499 Type *ResultTy = MemoryTy; 500 501 unsigned Unsigned = 1; 502 // The size of the gather was already checked in isLegalTypeAndAlignment; 503 // if it was not a full vector width an appropriate extend should follow. 504 auto *Extend = Root; 505 bool TruncResult = false; 506 if (MemoryTy->getPrimitiveSizeInBits() < 128) { 507 if (I->hasOneUse()) { 508 // If the gather has a single extend of the correct type, use an extending 509 // gather and replace the ext. In which case the correct root to replace 510 // is not the CallInst itself, but the instruction which extends it. 511 Instruction* User = cast<Instruction>(*I->users().begin()); 512 if (isa<SExtInst>(User) && 513 User->getType()->getPrimitiveSizeInBits() == 128) { 514 LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: " 515 << *User << "\n"); 516 Extend = User; 517 ResultTy = User->getType(); 518 Unsigned = 0; 519 } else if (isa<ZExtInst>(User) && 520 User->getType()->getPrimitiveSizeInBits() == 128) { 521 LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: " 522 << *ResultTy << "\n"); 523 Extend = User; 524 ResultTy = User->getType(); 525 } 526 } 527 528 // If an extend hasn't been found and the type is an integer, create an 529 // extending gather and truncate back to the original type. 530 if (ResultTy->getPrimitiveSizeInBits() < 128 && 531 ResultTy->isIntOrIntVectorTy()) { 532 ResultTy = ResultTy->getWithNewBitWidth( 533 128 / cast<FixedVectorType>(ResultTy)->getNumElements()); 534 TruncResult = true; 535 LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: " 536 << *ResultTy << "\n"); 537 } 538 539 // The final size of the gather must be a full vector width 540 if (ResultTy->getPrimitiveSizeInBits() != 128) { 541 LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided " 542 "from the correct type. Expanding\n"); 543 return nullptr; 544 } 545 } 546 547 Value *Offsets; 548 int Scale; 549 Value *BasePtr = decomposePtr( 550 Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder); 551 if (!BasePtr) 552 return nullptr; 553 554 Root = Extend; 555 Value *Mask = I->getArgOperand(2); 556 Instruction *Load = nullptr; 557 if (!match(Mask, m_One())) 558 Load = Builder.CreateIntrinsic( 559 Intrinsic::arm_mve_vldr_gather_offset_predicated, 560 {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, 561 {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()), 562 Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); 563 else 564 Load = Builder.CreateIntrinsic( 565 Intrinsic::arm_mve_vldr_gather_offset, 566 {ResultTy, BasePtr->getType(), Offsets->getType()}, 567 {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()), 568 Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); 569 570 if (TruncResult) { 571 Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy); 572 Builder.Insert(Load); 573 } 574 return Load; 575 } 576 577 Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { 578 using namespace PatternMatch; 579 LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n" 580 << *I << "\n"); 581 582 // @llvm.masked.scatter.*(data, ptrs, alignment, mask) 583 // Attempt to turn the masked scatter in I into a MVE intrinsic 584 // Potentially optimising the addressing modes as we do so. 585 Value *Input = I->getArgOperand(0); 586 Value *Ptr = I->getArgOperand(1); 587 Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); 588 auto *Ty = cast<FixedVectorType>(Input->getType()); 589 590 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 591 Alignment)) 592 return nullptr; 593 594 lookThroughBitcast(Ptr); 595 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 596 597 IRBuilder<> Builder(I->getContext()); 598 Builder.SetInsertPoint(I); 599 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 600 601 Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder); 602 if (!Store) 603 Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); 604 if (!Store) 605 Store = tryCreateMaskedScatterBase(I, Ptr, Builder); 606 if (!Store) 607 return nullptr; 608 609 LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n" 610 << *Store << "\n"); 611 I->eraseFromParent(); 612 return Store; 613 } 614 615 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase( 616 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 617 using namespace PatternMatch; 618 Value *Input = I->getArgOperand(0); 619 auto *Ty = cast<FixedVectorType>(Input->getType()); 620 // Only QR variants allow truncating 621 if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { 622 // Can't build an intrinsic for this 623 return nullptr; 624 } 625 Value *Mask = I->getArgOperand(3); 626 // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) 627 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); 628 if (match(Mask, m_One())) 629 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, 630 {Ptr->getType(), Input->getType()}, 631 {Ptr, Builder.getInt32(Increment), Input}); 632 else 633 return Builder.CreateIntrinsic( 634 Intrinsic::arm_mve_vstr_scatter_base_predicated, 635 {Ptr->getType(), Input->getType(), Mask->getType()}, 636 {Ptr, Builder.getInt32(Increment), Input, Mask}); 637 } 638 639 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( 640 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 641 using namespace PatternMatch; 642 Value *Input = I->getArgOperand(0); 643 auto *Ty = cast<FixedVectorType>(Input->getType()); 644 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers " 645 << "with writeback\n"); 646 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 647 // Can't build an intrinsic for this 648 return nullptr; 649 Value *Mask = I->getArgOperand(3); 650 if (match(Mask, m_One())) 651 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, 652 {Ptr->getType(), Input->getType()}, 653 {Ptr, Builder.getInt32(Increment), Input}); 654 else 655 return Builder.CreateIntrinsic( 656 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, 657 {Ptr->getType(), Input->getType(), Mask->getType()}, 658 {Ptr, Builder.getInt32(Increment), Input, Mask}); 659 } 660 661 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( 662 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 663 using namespace PatternMatch; 664 Value *Input = I->getArgOperand(0); 665 Value *Mask = I->getArgOperand(3); 666 Type *InputTy = Input->getType(); 667 Type *MemoryTy = InputTy; 668 669 LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing" 670 << " to base + vector of offsets\n"); 671 // If the input has been truncated, try to integrate that trunc into the 672 // scatter instruction (we don't care about alignment here) 673 if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) { 674 Value *PreTrunc = Trunc->getOperand(0); 675 Type *PreTruncTy = PreTrunc->getType(); 676 if (PreTruncTy->getPrimitiveSizeInBits() == 128) { 677 Input = PreTrunc; 678 InputTy = PreTruncTy; 679 } 680 } 681 bool ExtendInput = false; 682 if (InputTy->getPrimitiveSizeInBits() < 128 && 683 InputTy->isIntOrIntVectorTy()) { 684 // If we can't find a trunc to incorporate into the instruction, create an 685 // implicit one with a zext, so that we can still create a scatter. We know 686 // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type 687 // smaller than 128 bits will divide evenly into a 128bit vector. 688 InputTy = InputTy->getWithNewBitWidth( 689 128 / cast<FixedVectorType>(InputTy)->getNumElements()); 690 ExtendInput = true; 691 LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n" 692 << *Input << "\n"); 693 } 694 if (InputTy->getPrimitiveSizeInBits() != 128) { 695 LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for " 696 "non-standard input types. Expanding.\n"); 697 return nullptr; 698 } 699 700 Value *Offsets; 701 int Scale; 702 Value *BasePtr = decomposePtr( 703 Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder); 704 if (!BasePtr) 705 return nullptr; 706 707 if (ExtendInput) 708 Input = Builder.CreateZExt(Input, InputTy); 709 if (!match(Mask, m_One())) 710 return Builder.CreateIntrinsic( 711 Intrinsic::arm_mve_vstr_scatter_offset_predicated, 712 {BasePtr->getType(), Offsets->getType(), Input->getType(), 713 Mask->getType()}, 714 {BasePtr, Offsets, Input, 715 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 716 Builder.getInt32(Scale), Mask}); 717 else 718 return Builder.CreateIntrinsic( 719 Intrinsic::arm_mve_vstr_scatter_offset, 720 {BasePtr->getType(), Offsets->getType(), Input->getType()}, 721 {BasePtr, Offsets, Input, 722 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 723 Builder.getInt32(Scale)}); 724 } 725 726 Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat( 727 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 728 FixedVectorType *Ty; 729 if (I->getIntrinsicID() == Intrinsic::masked_gather) 730 Ty = cast<FixedVectorType>(I->getType()); 731 else 732 Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType()); 733 734 // Incrementing gathers only exist for v4i32 735 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 736 return nullptr; 737 // Incrementing gathers are not beneficial outside of a loop 738 Loop *L = LI->getLoopFor(I->getParent()); 739 if (L == nullptr) 740 return nullptr; 741 742 // Decompose the GEP into Base and Offsets 743 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 744 Value *Offsets; 745 Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder); 746 if (!BasePtr) 747 return nullptr; 748 749 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 750 "wb gather/scatter\n"); 751 752 // The gep was in charge of making sure the offsets are scaled correctly 753 // - calculate that factor so it can be applied by hand 754 int TypeScale = 755 computeScale(DL->getTypeSizeInBits(GEP->getSourceElementType()), 756 DL->getTypeSizeInBits(GEP->getType()) / 757 cast<FixedVectorType>(GEP->getType())->getNumElements()); 758 if (TypeScale == -1) 759 return nullptr; 760 761 if (GEP->hasOneUse()) { 762 // Only in this case do we want to build a wb gather, because the wb will 763 // change the phi which does affect other users of the gep (which will still 764 // be using the phi in the old way) 765 if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, 766 TypeScale, Builder)) 767 return Load; 768 } 769 770 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 771 "non-wb gather/scatter\n"); 772 773 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 774 if (Add.first == nullptr) 775 return nullptr; 776 Value *OffsetsIncoming = Add.first; 777 int64_t Immediate = Add.second; 778 779 // Make sure the offsets are scaled correctly 780 Instruction *ScaledOffsets = BinaryOperator::Create( 781 Instruction::Shl, OffsetsIncoming, 782 Builder.CreateVectorSplat(Ty->getNumElements(), 783 Builder.getInt32(TypeScale)), 784 "ScaledIndex", I->getIterator()); 785 // Add the base to the offsets 786 OffsetsIncoming = BinaryOperator::Create( 787 Instruction::Add, ScaledOffsets, 788 Builder.CreateVectorSplat( 789 Ty->getNumElements(), 790 Builder.CreatePtrToInt( 791 BasePtr, 792 cast<VectorType>(ScaledOffsets->getType())->getElementType())), 793 "StartIndex", I->getIterator()); 794 795 if (I->getIntrinsicID() == Intrinsic::masked_gather) 796 return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate); 797 else 798 return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate); 799 } 800 801 Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( 802 IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, 803 IRBuilder<> &Builder) { 804 // Check whether this gather's offset is incremented by a constant - if so, 805 // and the load is of the right type, we can merge this into a QI gather 806 Loop *L = LI->getLoopFor(I->getParent()); 807 // Offsets that are worth merging into this instruction will be incremented 808 // by a constant, thus we're looking for an add of a phi and a constant 809 PHINode *Phi = dyn_cast<PHINode>(Offsets); 810 if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || 811 Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) 812 // No phi means no IV to write back to; if there is a phi, we expect it 813 // to have exactly two incoming values; the only phis we are interested in 814 // will be loop IV's and have exactly two uses, one in their increment and 815 // one in the gather's gep 816 return nullptr; 817 818 unsigned IncrementIndex = 819 Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1; 820 // Look through the phi to the phi increment 821 Offsets = Phi->getIncomingValue(IncrementIndex); 822 823 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 824 if (Add.first == nullptr) 825 return nullptr; 826 Value *OffsetsIncoming = Add.first; 827 int64_t Immediate = Add.second; 828 if (OffsetsIncoming != Phi) 829 // Then the increment we are looking at is not an increment of the 830 // induction variable, and we don't want to do a writeback 831 return nullptr; 832 833 Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back()); 834 unsigned NumElems = 835 cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements(); 836 837 // Make sure the offsets are scaled correctly 838 Instruction *ScaledOffsets = BinaryOperator::Create( 839 Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex), 840 Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)), 841 "ScaledIndex", 842 Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator()); 843 // Add the base to the offsets 844 OffsetsIncoming = BinaryOperator::Create( 845 Instruction::Add, ScaledOffsets, 846 Builder.CreateVectorSplat( 847 NumElems, 848 Builder.CreatePtrToInt( 849 BasePtr, 850 cast<VectorType>(ScaledOffsets->getType())->getElementType())), 851 "StartIndex", 852 Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator()); 853 // The gather is pre-incrementing 854 OffsetsIncoming = BinaryOperator::Create( 855 Instruction::Sub, OffsetsIncoming, 856 Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)), 857 "PreIncrementStartIndex", 858 Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator()); 859 Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming); 860 861 Builder.SetInsertPoint(I); 862 863 Instruction *EndResult; 864 Instruction *NewInduction; 865 if (I->getIntrinsicID() == Intrinsic::masked_gather) { 866 // Build the incrementing gather 867 Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate); 868 // One value to be handed to whoever uses the gather, one is the loop 869 // increment 870 EndResult = ExtractValueInst::Create(Load, 0, "Gather"); 871 NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement"); 872 Builder.Insert(EndResult); 873 Builder.Insert(NewInduction); 874 } else { 875 // Build the incrementing scatter 876 EndResult = NewInduction = 877 tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate); 878 } 879 Instruction *AddInst = cast<Instruction>(Offsets); 880 AddInst->replaceAllUsesWith(NewInduction); 881 AddInst->eraseFromParent(); 882 Phi->setIncomingValue(IncrementIndex, NewInduction); 883 884 return EndResult; 885 } 886 887 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi, 888 Value *OffsSecondOperand, 889 unsigned StartIndex) { 890 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n"); 891 BasicBlock::iterator InsertionPoint = 892 Phi->getIncomingBlock(StartIndex)->back().getIterator(); 893 // Initialize the phi with a vector that contains a sum of the constants 894 Instruction *NewIndex = BinaryOperator::Create( 895 Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand, 896 "PushedOutAdd", InsertionPoint); 897 unsigned IncrementIndex = StartIndex == 0 ? 1 : 0; 898 899 // Order such that start index comes first (this reduces mov's) 900 Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex)); 901 Phi->addIncoming(Phi->getIncomingValue(IncrementIndex), 902 Phi->getIncomingBlock(IncrementIndex)); 903 Phi->removeIncomingValue(1); 904 Phi->removeIncomingValue((unsigned)0); 905 } 906 907 void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi, 908 Value *IncrementPerRound, 909 Value *OffsSecondOperand, 910 unsigned LoopIncrement, 911 IRBuilder<> &Builder) { 912 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); 913 914 // Create a new scalar add outside of the loop and transform it to a splat 915 // by which loop variable can be incremented 916 BasicBlock::iterator InsertionPoint = 917 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back().getIterator(); 918 919 // Create a new index 920 Value *StartIndex = 921 BinaryOperator::Create((Instruction::BinaryOps)Opcode, 922 Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), 923 OffsSecondOperand, "PushedOutMul", InsertionPoint); 924 925 Instruction *Product = 926 BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound, 927 OffsSecondOperand, "Product", InsertionPoint); 928 929 BasicBlock::iterator NewIncrInsertPt = 930 Phi->getIncomingBlock(LoopIncrement)->back().getIterator(); 931 NewIncrInsertPt = std::prev(NewIncrInsertPt); 932 933 // Increment NewIndex by Product instead of the multiplication 934 Instruction *NewIncrement = BinaryOperator::Create( 935 Instruction::Add, Phi, Product, "IncrementPushedOutMul", NewIncrInsertPt); 936 937 Phi->addIncoming(StartIndex, 938 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)); 939 Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement)); 940 Phi->removeIncomingValue((unsigned)0); 941 Phi->removeIncomingValue((unsigned)0); 942 } 943 944 // Check whether all usages of this instruction are as offsets of 945 // gathers/scatters or simple arithmetics only used by gathers/scatters 946 static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) { 947 if (I->hasNUses(0)) { 948 return false; 949 } 950 bool Gatscat = true; 951 for (User *U : I->users()) { 952 if (!isa<Instruction>(U)) 953 return false; 954 if (isa<GetElementPtrInst>(U) || 955 isGatherScatter(dyn_cast<IntrinsicInst>(U))) { 956 return Gatscat; 957 } else { 958 unsigned OpCode = cast<Instruction>(U)->getOpcode(); 959 if ((OpCode == Instruction::Add || OpCode == Instruction::Mul || 960 OpCode == Instruction::Shl || 961 isAddLikeOr(cast<Instruction>(U), DL)) && 962 hasAllGatScatUsers(cast<Instruction>(U), DL)) { 963 continue; 964 } 965 return false; 966 } 967 } 968 return Gatscat; 969 } 970 971 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB, 972 LoopInfo *LI) { 973 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: " 974 << *Offsets << "\n"); 975 // Optimise the addresses of gathers/scatters by moving invariant 976 // calculations out of the loop 977 if (!isa<Instruction>(Offsets)) 978 return false; 979 Instruction *Offs = cast<Instruction>(Offsets); 980 if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) && 981 Offs->getOpcode() != Instruction::Mul && 982 Offs->getOpcode() != Instruction::Shl) 983 return false; 984 Loop *L = LI->getLoopFor(BB); 985 if (L == nullptr) 986 return false; 987 if (!Offs->hasOneUse()) { 988 if (!hasAllGatScatUsers(Offs, *DL)) 989 return false; 990 } 991 992 // Find out which, if any, operand of the instruction 993 // is a phi node 994 PHINode *Phi; 995 int OffsSecondOp; 996 if (isa<PHINode>(Offs->getOperand(0))) { 997 Phi = cast<PHINode>(Offs->getOperand(0)); 998 OffsSecondOp = 1; 999 } else if (isa<PHINode>(Offs->getOperand(1))) { 1000 Phi = cast<PHINode>(Offs->getOperand(1)); 1001 OffsSecondOp = 0; 1002 } else { 1003 bool Changed = false; 1004 if (isa<Instruction>(Offs->getOperand(0)) && 1005 L->contains(cast<Instruction>(Offs->getOperand(0)))) 1006 Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); 1007 if (isa<Instruction>(Offs->getOperand(1)) && 1008 L->contains(cast<Instruction>(Offs->getOperand(1)))) 1009 Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); 1010 if (!Changed) 1011 return false; 1012 if (isa<PHINode>(Offs->getOperand(0))) { 1013 Phi = cast<PHINode>(Offs->getOperand(0)); 1014 OffsSecondOp = 1; 1015 } else if (isa<PHINode>(Offs->getOperand(1))) { 1016 Phi = cast<PHINode>(Offs->getOperand(1)); 1017 OffsSecondOp = 0; 1018 } else { 1019 return false; 1020 } 1021 } 1022 // A phi node we want to perform this function on should be from the 1023 // loop header. 1024 if (Phi->getParent() != L->getHeader()) 1025 return false; 1026 1027 // We're looking for a simple add recurrence. 1028 BinaryOperator *IncInstruction; 1029 Value *Start, *IncrementPerRound; 1030 if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) || 1031 IncInstruction->getOpcode() != Instruction::Add) 1032 return false; 1033 1034 int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1; 1035 1036 // Get the value that is added to/multiplied with the phi 1037 Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp); 1038 1039 if (IncrementPerRound->getType() != OffsSecondOperand->getType() || 1040 !L->isLoopInvariant(OffsSecondOperand)) 1041 // Something has gone wrong, abort 1042 return false; 1043 1044 // Only proceed if the increment per round is a constant or an instruction 1045 // which does not originate from within the loop 1046 if (!isa<Constant>(IncrementPerRound) && 1047 !(isa<Instruction>(IncrementPerRound) && 1048 !L->contains(cast<Instruction>(IncrementPerRound)))) 1049 return false; 1050 1051 // If the phi is not used by anything else, we can just adapt it when 1052 // replacing the instruction; if it is, we'll have to duplicate it 1053 PHINode *NewPhi; 1054 if (Phi->getNumUses() == 2) { 1055 // No other users -> reuse existing phi (One user is the instruction 1056 // we're looking at, the other is the phi increment) 1057 if (IncInstruction->getNumUses() != 1) { 1058 // If the incrementing instruction does have more users than 1059 // our phi, we need to copy it 1060 IncInstruction = BinaryOperator::Create( 1061 Instruction::BinaryOps(IncInstruction->getOpcode()), Phi, 1062 IncrementPerRound, "LoopIncrement", IncInstruction->getIterator()); 1063 Phi->setIncomingValue(IncrementingBlock, IncInstruction); 1064 } 1065 NewPhi = Phi; 1066 } else { 1067 // There are other users -> create a new phi 1068 NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi->getIterator()); 1069 // Copy the incoming values of the old phi 1070 NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1), 1071 Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1)); 1072 IncInstruction = BinaryOperator::Create( 1073 Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi, 1074 IncrementPerRound, "LoopIncrement", IncInstruction->getIterator()); 1075 NewPhi->addIncoming(IncInstruction, 1076 Phi->getIncomingBlock(IncrementingBlock)); 1077 IncrementingBlock = 1; 1078 } 1079 1080 IRBuilder<> Builder(BB->getContext()); 1081 Builder.SetInsertPoint(Phi); 1082 Builder.SetCurrentDebugLocation(Offs->getDebugLoc()); 1083 1084 switch (Offs->getOpcode()) { 1085 case Instruction::Add: 1086 case Instruction::Or: 1087 pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1); 1088 break; 1089 case Instruction::Mul: 1090 case Instruction::Shl: 1091 pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound, 1092 OffsSecondOperand, IncrementingBlock, Builder); 1093 break; 1094 default: 1095 return false; 1096 } 1097 LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable " 1098 << "add/mul\n"); 1099 1100 // The instruction has now been "absorbed" into the phi value 1101 Offs->replaceAllUsesWith(NewPhi); 1102 if (Offs->hasNUses(0)) 1103 Offs->eraseFromParent(); 1104 // Clean up the old increment in case it's unused because we built a new 1105 // one 1106 if (IncInstruction->hasNUses(0)) 1107 IncInstruction->eraseFromParent(); 1108 1109 return true; 1110 } 1111 1112 static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y, 1113 unsigned ScaleY, IRBuilder<> &Builder) { 1114 // Splat the non-vector value to a vector of the given type - if the value is 1115 // a constant (and its value isn't too big), we can even use this opportunity 1116 // to scale it to the size of the vector elements 1117 auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) { 1118 ConstantInt *Const; 1119 if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) && 1120 VT->getElementType() != NonVectorVal->getType()) { 1121 unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits(); 1122 uint64_t N = Const->getZExtValue(); 1123 if (N < (unsigned)(1 << (TargetElemSize - 1))) { 1124 NonVectorVal = Builder.CreateVectorSplat( 1125 VT->getNumElements(), Builder.getIntN(TargetElemSize, N)); 1126 return; 1127 } 1128 } 1129 NonVectorVal = 1130 Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal); 1131 }; 1132 1133 FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType()); 1134 FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType()); 1135 // If one of X, Y is not a vector, we have to splat it in order 1136 // to add the two of them. 1137 if (XElType && !YElType) { 1138 FixSummands(XElType, Y); 1139 YElType = cast<FixedVectorType>(Y->getType()); 1140 } else if (YElType && !XElType) { 1141 FixSummands(YElType, X); 1142 XElType = cast<FixedVectorType>(X->getType()); 1143 } 1144 assert(XElType && YElType && "Unknown vector types"); 1145 // Check that the summands are of compatible types 1146 if (XElType != YElType) { 1147 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n"); 1148 return nullptr; 1149 } 1150 1151 if (XElType->getElementType()->getScalarSizeInBits() != 32) { 1152 // Check that by adding the vectors we do not accidentally 1153 // create an overflow 1154 Constant *ConstX = dyn_cast<Constant>(X); 1155 Constant *ConstY = dyn_cast<Constant>(Y); 1156 if (!ConstX || !ConstY) 1157 return nullptr; 1158 unsigned TargetElemSize = 128 / XElType->getNumElements(); 1159 for (unsigned i = 0; i < XElType->getNumElements(); i++) { 1160 ConstantInt *ConstXEl = 1161 dyn_cast<ConstantInt>(ConstX->getAggregateElement(i)); 1162 ConstantInt *ConstYEl = 1163 dyn_cast<ConstantInt>(ConstY->getAggregateElement(i)); 1164 if (!ConstXEl || !ConstYEl || 1165 ConstXEl->getZExtValue() * ScaleX + 1166 ConstYEl->getZExtValue() * ScaleY >= 1167 (unsigned)(1 << (TargetElemSize - 1))) 1168 return nullptr; 1169 } 1170 } 1171 1172 Value *XScale = Builder.CreateVectorSplat( 1173 XElType->getNumElements(), 1174 Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX)); 1175 Value *YScale = Builder.CreateVectorSplat( 1176 YElType->getNumElements(), 1177 Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY)); 1178 Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale), 1179 Builder.CreateMul(Y, YScale)); 1180 1181 if (checkOffsetSize(Add, XElType->getNumElements())) 1182 return Add; 1183 else 1184 return nullptr; 1185 } 1186 1187 Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP, 1188 Value *&Offsets, unsigned &Scale, 1189 IRBuilder<> &Builder) { 1190 Value *GEPPtr = GEP->getPointerOperand(); 1191 Offsets = GEP->getOperand(1); 1192 Scale = DL->getTypeAllocSize(GEP->getSourceElementType()); 1193 // We only merge geps with constant offsets, because only for those 1194 // we can make sure that we do not cause an overflow 1195 if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets)) 1196 return nullptr; 1197 if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) { 1198 // Merge the two geps into one 1199 Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder); 1200 if (!BaseBasePtr) 1201 return nullptr; 1202 Offsets = CheckAndCreateOffsetAdd( 1203 Offsets, Scale, GEP->getOperand(1), 1204 DL->getTypeAllocSize(GEP->getSourceElementType()), Builder); 1205 if (Offsets == nullptr) 1206 return nullptr; 1207 Scale = 1; // Scale is always an i8 at this point. 1208 return BaseBasePtr; 1209 } 1210 return GEPPtr; 1211 } 1212 1213 bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB, 1214 LoopInfo *LI) { 1215 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address); 1216 if (!GEP) 1217 return false; 1218 bool Changed = false; 1219 if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) { 1220 IRBuilder<> Builder(GEP->getContext()); 1221 Builder.SetInsertPoint(GEP); 1222 Builder.SetCurrentDebugLocation(GEP->getDebugLoc()); 1223 Value *Offsets; 1224 unsigned Scale; 1225 Value *Base = foldGEP(GEP, Offsets, Scale, Builder); 1226 // We only want to merge the geps if there is a real chance that they can be 1227 // used by an MVE gather; thus the offset has to have the correct size 1228 // (always i32 if it is not of vector type) and the base has to be a 1229 // pointer. 1230 if (Offsets && Base && Base != GEP) { 1231 assert(Scale == 1 && "Expected to fold GEP to a scale of 1"); 1232 Type *BaseTy = Builder.getPtrTy(); 1233 if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType())) 1234 BaseTy = FixedVectorType::get(BaseTy, VecTy); 1235 GetElementPtrInst *NewAddress = GetElementPtrInst::Create( 1236 Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets, 1237 "gep.merged", GEP->getIterator()); 1238 LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP 1239 << "\n new : " << *NewAddress << "\n"); 1240 GEP->replaceAllUsesWith( 1241 Builder.CreateBitCast(NewAddress, GEP->getType())); 1242 GEP = NewAddress; 1243 Changed = true; 1244 } 1245 } 1246 Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI); 1247 return Changed; 1248 } 1249 1250 bool MVEGatherScatterLowering::runOnFunction(Function &F) { 1251 if (!EnableMaskedGatherScatters) 1252 return false; 1253 auto &TPC = getAnalysis<TargetPassConfig>(); 1254 auto &TM = TPC.getTM<TargetMachine>(); 1255 auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 1256 if (!ST->hasMVEIntegerOps()) 1257 return false; 1258 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 1259 DL = &F.getDataLayout(); 1260 SmallVector<IntrinsicInst *, 4> Gathers; 1261 SmallVector<IntrinsicInst *, 4> Scatters; 1262 1263 bool Changed = false; 1264 1265 for (BasicBlock &BB : F) { 1266 Changed |= SimplifyInstructionsInBlock(&BB); 1267 1268 for (Instruction &I : BB) { 1269 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 1270 if (II && II->getIntrinsicID() == Intrinsic::masked_gather && 1271 isa<FixedVectorType>(II->getType())) { 1272 Gathers.push_back(II); 1273 Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI); 1274 } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter && 1275 isa<FixedVectorType>(II->getArgOperand(0)->getType())) { 1276 Scatters.push_back(II); 1277 Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI); 1278 } 1279 } 1280 } 1281 for (IntrinsicInst *I : Gathers) { 1282 Instruction *L = lowerGather(I); 1283 if (L == nullptr) 1284 continue; 1285 1286 // Get rid of any now dead instructions 1287 SimplifyInstructionsInBlock(L->getParent()); 1288 Changed = true; 1289 } 1290 1291 for (IntrinsicInst *I : Scatters) { 1292 Instruction *S = lowerScatter(I); 1293 if (S == nullptr) 1294 continue; 1295 1296 // Get rid of any now dead instructions 1297 SimplifyInstructionsInBlock(S->getParent()); 1298 Changed = true; 1299 } 1300 return Changed; 1301 } 1302