//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.gather and llvm.scatter instructions to
/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
/// produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *checkGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP,
                  IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Value *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                     Instruction *&Root, IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                   IRBuilder<> &Builder, int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                     IRBuilder<> &Builder,
                                     int64_t Increment = 0);

  Value *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                      IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                    IRBuilder<> &Builder,
                                    int64_t Increment = 0);
  // Create an incrementing scatter to a vector of pointers
  Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                      IRBuilder<> &Builder,
                                      int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (digit)
  Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
                                      Value *Ptr, GetElementPtrInst *GEP,
                                      IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (digit) - this creates a writeback QI
  // gather/scatter
  Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                        Value *Ptr, unsigned TypeScale,
                                        IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

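// The legal type/alignment combinations below are exactly the 128-bit MVE
// vector shapes: 4 x 32/16/8-bit, 8 x 16/8-bit or 16 x 8-bit elements, each
// needing at least element-size (ElemSize / 8 bytes) alignment.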
bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}

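// For example, a gather with 8 lanes has a 16-bit target element size
// (128 / 8), so non-i32 constant offsets are only accepted if each lane's
// value lies in [0, 1 << 16).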
static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller or equal the type of the gather we're
  // looking at, or consist of constants that we can check are small enough
  // to fit into the gather type.
  // Thus we check that 0 < value < 2^TargetElemSize.
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}

Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty,
                                          GetElementPtrInst *GEP,
                                          IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(
        dbgs() << "masked gathers/scatters: no getelementpointer found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(
          dbgs() << "masked gathers/scatters: looking through bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

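// The scale returned below is the log2 of the byte factor the hardware
// applies to each offset: 2 for 32-bit elements (offsets scaled by 4), 1 for
// 16-bit elements (scaled by 2) and 0 for byte offsets (unscaled).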
int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by
  // 2, or an 8bit, 16bit or 32bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

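// Recursively folds adds and muls of constants: e.g. a value computed as
// (2 + 3) * 4 yields Optional<int64_t>{20}, while anything containing a
// non-constant leaf yields an empty Optional.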
Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C != nullptr)
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add ||
      I->getOpcode() == Instruction::Mul) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  }
  return Optional<int64_t>{};
}

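// The scaled immediate has to be a multiple of 4 in [-512, 512]: e.g. a
// constant summand of 16 with TypeScale 2 gives the byte immediate 64 and is
// accepted, whereas a summand of 256 (scaled to 1024) is rejected.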
std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or we
  // bail out
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
    return ReturnFalse;

  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand the value that is increased is
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = Const.getValue() << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

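// Illustrative sketch of the lowering (hand-written IR, intrinsic name
// suffixes elided):
//   %g = call <4 x i32> @llvm.masked.gather...(<4 x i32*> %ptrs, i32 4,
//                                              <4 x i1> %mask, <4 x i32> %pt)
// becomes, roughly,
//   %g = call <4 x i32> @llvm.arm.mve.vldr.gather.base.predicated...(
//            <4 x i32*> %ptrs, i32 0, <4 x i1> %mask)
// with a select against %pt added only for non-trivial passthrus.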
Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into a MVE intrinsic, potentially
  // optimising the addressing modes as we do so.
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;
  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = Builder.CreateSelect(Mask, Load, PassThru);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
                    << *Load << "\n");
  return Load;
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
                                                           Value *Ptr,
                                                           IRBuilder<> &Builder,
                                                           int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked gathers: loading from vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

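// This also matches extending gathers: e.g. a <4 x i8> masked gather whose
// single user is a zext to <4 x i32> is turned into one unsigned extending
// gather, and the zext (the new Root) is what gets replaced.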
Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *OriginalTy = I->getType();
  Type *ResultTy = OriginalTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
    // Only transform gathers with exactly one use
    if (!I->hasOneUse())
      return nullptr;

    // The correct root to replace is not the CallInst itself, but the
    // instruction which extends it
    Extend = cast<Instruction>(*I->users().begin());
    if (isa<SExtInst>(Extend)) {
      Unsigned = 0;
    } else if (!isa<ZExtInst>(Extend)) {
      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
                        << "Expanding\n");
      return nullptr;
    }
    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
    ResultTy = Extend->getType();
    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
                        << "Expanding\n");
      return nullptr;
    }
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI gather
  Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Load)
    return Load;

  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      OriginalTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;
  Root = Extend;

  Value *Mask = I->getArgOperand(2);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
}

Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into a MVE intrinsic, potentially
  // optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
                    << *Store << "\n");
  I->eraseFromParent();
  return Store;
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked scatters: storing to a vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;
  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
                         "non-standard input types. Expanding.\n");
    return nullptr;
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI scatter
  Value *Store =
      tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Store)
    return Store;
  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      MemoryTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;

  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

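// If the offsets are "loop-varying vector + constant", the constant part can
// become the QI instruction's immediate increment: e.g. offsets of (%phi + 4)
// with TypeScale 2 yield a base vector of (basePtr + (%phi << 2)) and a byte
// immediate of 16.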
Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
    IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    // Incrementing gathers are not beneficial outside of a loop
    return nullptr;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
  int TypeScale =
      computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DT.getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the wb will
    // change the phi which does affect other users of the gep (which will
    // still be using the phi in the old way)
    Value *Load =
        tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder);
    if (Load != nullptr)
      return Load;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return cast<IntrinsicInst>(
        tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
  else
    return cast<IntrinsicInst>(
        tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
}

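// The writeback form rewrites the offset phi itself: its start value becomes
// the scaled start address, and the new address returned by the intrinsic
// (extracted as the second result for gathers) feeds the phi's back edge, so
// the separate offset add inside the loop disappears.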
Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Value *EndResult;
  Value *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
  } else {
    // Build the incrementing scatter
    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
    EndResult = NewInduction;
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

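// Pushing an add out relies on it commuting with the increment:
// (Start + n * Inc) + C == (Start + C) + n * Inc, so adding C to the phi's
// start value outside the loop makes the in-loop add redundant.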
void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

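// Pushing a mul out relies on distributivity:
// (Start + n * Inc) * C == Start * C + n * (Inc * C), so the phi can start at
// Start * C and step by Inc * C instead of being multiplied every round.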
void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which loop variable can be incremented
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetic only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

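// For example, offsets computed in the loop as (%iv * 3) + 7, with %iv an
// induction phi, are rewritten (via the recursion below plus pushOutMul and
// pushOutAdd) so that the mul and add happen once outside the loop and only
// a stepped phi remains inside it.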
bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"
                    << *Offsets << "\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed) {
      return false;
    } else {
      if (isa<PHINode>(Offs->getOperand(0))) {
        Phi = cast<PHINode>(Offs->getOperand(0));
        OffsSecondOp = 1;
      } else if (isa<PHINode>(Offs->getOperand(1))) {
        Phi = cast<PHINode>(Offs->getOperand(1));
        OffsSecondOp = 0;
      } else {
        return false;
      }
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header, and shouldn't have more than 2 incoming values
  if (Phi->getParent() != L->getHeader() || Phi->getNumIncomingValues() != 2)
    return false;

  // The phi must be an induction variable
  int IncrementingBlock = -1;

  for (int i = 0; i < 2; i++)
    if (auto *Op = dyn_cast<Instruction>(Phi->getIncomingValue(i)))
      if (Op->getOpcode() == Instruction::Add &&
          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
        IncrementingBlock = i;
  if (IncrementingBlock == -1)
    return false;

  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType())
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *Add = Builder.CreateAdd(X, Y);

  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}

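// E.g. gep(gep(%base, <4 x i32> %c0), <4 x i32> %c1) is folded to
// gep(%base, %c0 + %c1); only constant offsets are merged so the add can be
// proven not to overflow.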
Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  // We only merge geps with constant offsets, because only for those
  // we can make sure that we do not cause an overflow
  if (!isa<Constant>(Offsets))
    return nullptr;
  GetElementPtrInst *BaseGEP;
  if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
    // Merge the two geps into one
    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
    if (!BaseBasePtr)
      return nullptr;
    Offsets =
        CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
    if (Offsets == nullptr)
      return nullptr;
    return BaseBasePtr;
  }
  return GEPPtr;
}

bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    Value *Base = foldGEP(GEP, Offsets, Builder);
    // We only want to merge the geps if there is a real chance that they can
    // be used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      PointerType *BaseType = cast<PointerType>(Base->getType());
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          BaseType->getPointerElementType(), Base, Offsets, "gep.merged", GEP);
      GEP->replaceAllUsesWith(NewAddress);
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Value *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Value *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
    Changed = true;
  }
  return Changed;
}