xref: /openbsd-src/gnu/llvm/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp (revision d415bd752c734aee168c4ee86ff32e8cc249eb16)
109467b48Spatrick //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===//
209467b48Spatrick //
309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information.
509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609467b48Spatrick //
709467b48Spatrick //===----------------------------------------------------------------------===//
//
/// \file This pass rewrites uses of AMDGPU kernel arguments as loads from
/// fixed offsets relative to the kernarg segment base pointer.
1109467b48Spatrick //
1209467b48Spatrick //===----------------------------------------------------------------------===//
1309467b48Spatrick 
1409467b48Spatrick #include "AMDGPU.h"
1573471bf0Spatrick #include "GCNSubtarget.h"
1609467b48Spatrick #include "llvm/CodeGen/TargetPassConfig.h"
1773471bf0Spatrick #include "llvm/IR/IntrinsicsAMDGPU.h"
1809467b48Spatrick #include "llvm/IR/IRBuilder.h"
1909467b48Spatrick #include "llvm/IR/MDBuilder.h"
2073471bf0Spatrick #include "llvm/Target/TargetMachine.h"
2109467b48Spatrick #define DEBUG_TYPE "amdgpu-lower-kernel-arguments"
2209467b48Spatrick 
2309467b48Spatrick using namespace llvm;
2409467b48Spatrick 
2509467b48Spatrick namespace {
2609467b48Spatrick 
2709467b48Spatrick class AMDGPULowerKernelArguments : public FunctionPass{
2809467b48Spatrick public:
2909467b48Spatrick   static char ID;
3009467b48Spatrick 
AMDGPULowerKernelArguments()3109467b48Spatrick   AMDGPULowerKernelArguments() : FunctionPass(ID) {}
3209467b48Spatrick 
3309467b48Spatrick   bool runOnFunction(Function &F) override;
3409467b48Spatrick 
getAnalysisUsage(AnalysisUsage & AU) const3509467b48Spatrick   void getAnalysisUsage(AnalysisUsage &AU) const override {
3609467b48Spatrick     AU.addRequired<TargetPassConfig>();
3709467b48Spatrick     AU.setPreservesAll();
3809467b48Spatrick  }
3909467b48Spatrick };
4009467b48Spatrick 
4109467b48Spatrick } // end anonymous namespace
4209467b48Spatrick 
43097a140dSpatrick // skip allocas
getInsertPt(BasicBlock & BB)44097a140dSpatrick static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
45097a140dSpatrick   BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
46097a140dSpatrick   for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
47097a140dSpatrick     AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);
48097a140dSpatrick 
49097a140dSpatrick     // If this is a dynamic alloca, the value may depend on the loaded kernargs,
50097a140dSpatrick     // so loads will need to be inserted before it.
51097a140dSpatrick     if (!AI || !AI->isStaticAlloca())
52097a140dSpatrick       break;
53097a140dSpatrick   }
54097a140dSpatrick 
55097a140dSpatrick   return InsPt;
56097a140dSpatrick }
57097a140dSpatrick 
runOnFunction(Function & F)5809467b48Spatrick bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
5909467b48Spatrick   CallingConv::ID CC = F.getCallingConv();
6009467b48Spatrick   if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
6109467b48Spatrick     return false;
6209467b48Spatrick 
6309467b48Spatrick   auto &TPC = getAnalysis<TargetPassConfig>();
6409467b48Spatrick 
6509467b48Spatrick   const TargetMachine &TM = TPC.getTM<TargetMachine>();
6609467b48Spatrick   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
6709467b48Spatrick   LLVMContext &Ctx = F.getParent()->getContext();
6809467b48Spatrick   const DataLayout &DL = F.getParent()->getDataLayout();
6909467b48Spatrick   BasicBlock &EntryBlock = *F.begin();
70097a140dSpatrick   IRBuilder<> Builder(&*getInsertPt(EntryBlock));
7109467b48Spatrick 
7209467b48Spatrick   const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
7309467b48Spatrick   const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);
7409467b48Spatrick 
7509467b48Spatrick   Align MaxAlign;
76*d415bd75Srobert   // FIXME: Alignment is broken with explicit arg offset.;
7709467b48Spatrick   const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
7809467b48Spatrick   if (TotalKernArgSize == 0)
7909467b48Spatrick     return false;
8009467b48Spatrick 
8109467b48Spatrick   CallInst *KernArgSegment =
8209467b48Spatrick       Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
8309467b48Spatrick                               nullptr, F.getName() + ".kernarg.segment");
8409467b48Spatrick 
85*d415bd75Srobert   KernArgSegment->addRetAttr(Attribute::NonNull);
86*d415bd75Srobert   KernArgSegment->addRetAttr(
8709467b48Spatrick       Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
8809467b48Spatrick 
8909467b48Spatrick   unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
9009467b48Spatrick   uint64_t ExplicitArgOffset = 0;
9109467b48Spatrick 
9209467b48Spatrick   for (Argument &Arg : F.args()) {
9373471bf0Spatrick     const bool IsByRef = Arg.hasByRefAttr();
9473471bf0Spatrick     Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType();
95*d415bd75Srobert     MaybeAlign ParamAlign = IsByRef ? Arg.getParamAlign() : std::nullopt;
96*d415bd75Srobert     Align ABITypeAlign = DL.getValueOrABITypeAlignment(ParamAlign, ArgTy);
9773471bf0Spatrick 
9873471bf0Spatrick     uint64_t Size = DL.getTypeSizeInBits(ArgTy);
9973471bf0Spatrick     uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);
10009467b48Spatrick 
10109467b48Spatrick     uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
10209467b48Spatrick     ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;
10309467b48Spatrick 
10409467b48Spatrick     if (Arg.use_empty())
10509467b48Spatrick       continue;
10609467b48Spatrick 
10773471bf0Spatrick     // If this is byval, the loads are already explicit in the function. We just
10873471bf0Spatrick     // need to rewrite the pointer values.
10973471bf0Spatrick     if (IsByRef) {
11073471bf0Spatrick       Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64(
11173471bf0Spatrick           Builder.getInt8Ty(), KernArgSegment, EltOffset,
11273471bf0Spatrick           Arg.getName() + ".byval.kernarg.offset");
11373471bf0Spatrick 
11473471bf0Spatrick       Value *CastOffsetPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
11573471bf0Spatrick           ArgOffsetPtr, Arg.getType());
11673471bf0Spatrick       Arg.replaceAllUsesWith(CastOffsetPtr);
11773471bf0Spatrick       continue;
11873471bf0Spatrick     }
11973471bf0Spatrick 
12009467b48Spatrick     if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
12109467b48Spatrick       // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
12209467b48Spatrick       // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
12309467b48Spatrick       // can't represent this with range metadata because it's only allowed for
12409467b48Spatrick       // integer types.
12509467b48Spatrick       if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
12609467b48Spatrick            PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
12709467b48Spatrick           !ST.hasUsableDSOffset())
12809467b48Spatrick         continue;
12909467b48Spatrick 
13009467b48Spatrick       // FIXME: We can replace this with equivalent alias.scope/noalias
13109467b48Spatrick       // metadata, but this appears to be a lot of work.
13209467b48Spatrick       if (Arg.hasNoAliasAttr())
13309467b48Spatrick         continue;
13409467b48Spatrick     }
13509467b48Spatrick 
136097a140dSpatrick     auto *VT = dyn_cast<FixedVectorType>(ArgTy);
13709467b48Spatrick     bool IsV3 = VT && VT->getNumElements() == 3;
13809467b48Spatrick     bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();
13909467b48Spatrick 
14009467b48Spatrick     VectorType *V4Ty = nullptr;
14109467b48Spatrick 
14209467b48Spatrick     int64_t AlignDownOffset = alignDown(EltOffset, 4);
14309467b48Spatrick     int64_t OffsetDiff = EltOffset - AlignDownOffset;
14409467b48Spatrick     Align AdjustedAlign = commonAlignment(
14509467b48Spatrick         KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);
14609467b48Spatrick 
14709467b48Spatrick     Value *ArgPtr;
14809467b48Spatrick     Type *AdjustedArgTy;
14909467b48Spatrick     if (DoShiftOpt) { // FIXME: Handle aggregate types
15009467b48Spatrick       // Since we don't have sub-dword scalar loads, avoid doing an extload by
15109467b48Spatrick       // loading earlier than the argument address, and extracting the relevant
15209467b48Spatrick       // bits.
15309467b48Spatrick       //
15409467b48Spatrick       // Additionally widen any sub-dword load to i32 even if suitably aligned,
15509467b48Spatrick       // so that CSE between different argument loads works easily.
15609467b48Spatrick       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
15709467b48Spatrick           Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
15809467b48Spatrick           Arg.getName() + ".kernarg.offset.align.down");
15909467b48Spatrick       AdjustedArgTy = Builder.getInt32Ty();
16009467b48Spatrick     } else {
16109467b48Spatrick       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
16209467b48Spatrick           Builder.getInt8Ty(), KernArgSegment, EltOffset,
16309467b48Spatrick           Arg.getName() + ".kernarg.offset");
16409467b48Spatrick       AdjustedArgTy = ArgTy;
16509467b48Spatrick     }
16609467b48Spatrick 
16709467b48Spatrick     if (IsV3 && Size >= 32) {
168097a140dSpatrick       V4Ty = FixedVectorType::get(VT->getElementType(), 4);
16909467b48Spatrick       // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
17009467b48Spatrick       AdjustedArgTy = V4Ty;
17109467b48Spatrick     }
17209467b48Spatrick 
17309467b48Spatrick     ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
17409467b48Spatrick                                    ArgPtr->getName() + ".cast");
17509467b48Spatrick     LoadInst *Load =
176097a140dSpatrick         Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
17709467b48Spatrick     Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
17809467b48Spatrick 
17909467b48Spatrick     MDBuilder MDB(Ctx);
18009467b48Spatrick 
18109467b48Spatrick     if (isa<PointerType>(ArgTy)) {
18209467b48Spatrick       if (Arg.hasNonNullAttr())
18309467b48Spatrick         Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
18409467b48Spatrick 
18509467b48Spatrick       uint64_t DerefBytes = Arg.getDereferenceableBytes();
18609467b48Spatrick       if (DerefBytes != 0) {
18709467b48Spatrick         Load->setMetadata(
18809467b48Spatrick           LLVMContext::MD_dereferenceable,
18909467b48Spatrick           MDNode::get(Ctx,
19009467b48Spatrick                       MDB.createConstant(
19109467b48Spatrick                         ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
19209467b48Spatrick       }
19309467b48Spatrick 
19409467b48Spatrick       uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
19509467b48Spatrick       if (DerefOrNullBytes != 0) {
19609467b48Spatrick         Load->setMetadata(
19709467b48Spatrick           LLVMContext::MD_dereferenceable_or_null,
19809467b48Spatrick           MDNode::get(Ctx,
19909467b48Spatrick                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
20009467b48Spatrick                                                           DerefOrNullBytes))));
20109467b48Spatrick       }
20209467b48Spatrick 
203*d415bd75Srobert       if (MaybeAlign ParamAlign = Arg.getParamAlign()) {
20409467b48Spatrick         Load->setMetadata(
20509467b48Spatrick             LLVMContext::MD_align,
206*d415bd75Srobert             MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
207*d415bd75Srobert                                  Builder.getInt64Ty(), ParamAlign->value()))));
20809467b48Spatrick       }
20909467b48Spatrick     }
21009467b48Spatrick 
21109467b48Spatrick     // TODO: Convert noalias arg to !noalias
21209467b48Spatrick 
21309467b48Spatrick     if (DoShiftOpt) {
21409467b48Spatrick       Value *ExtractBits = OffsetDiff == 0 ?
21509467b48Spatrick         Load : Builder.CreateLShr(Load, OffsetDiff * 8);
21609467b48Spatrick 
21709467b48Spatrick       IntegerType *ArgIntTy = Builder.getIntNTy(Size);
21809467b48Spatrick       Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
21909467b48Spatrick       Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
22009467b48Spatrick                                             Arg.getName() + ".load");
22109467b48Spatrick       Arg.replaceAllUsesWith(NewVal);
22209467b48Spatrick     } else if (IsV3) {
22373471bf0Spatrick       Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2},
22409467b48Spatrick                                                 Arg.getName() + ".load");
22509467b48Spatrick       Arg.replaceAllUsesWith(Shuf);
22609467b48Spatrick     } else {
22709467b48Spatrick       Load->setName(Arg.getName() + ".load");
22809467b48Spatrick       Arg.replaceAllUsesWith(Load);
22909467b48Spatrick     }
23009467b48Spatrick   }
23109467b48Spatrick 
232*d415bd75Srobert   KernArgSegment->addRetAttr(
23309467b48Spatrick       Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
23409467b48Spatrick 
23509467b48Spatrick   return true;
23609467b48Spatrick }
23709467b48Spatrick 
23809467b48Spatrick INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
23909467b48Spatrick                       "AMDGPU Lower Kernel Arguments", false, false)
24009467b48Spatrick INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
24109467b48Spatrick                     false, false)
24209467b48Spatrick 
24309467b48Spatrick char AMDGPULowerKernelArguments::ID = 0;
24409467b48Spatrick 
createAMDGPULowerKernelArgumentsPass()24509467b48Spatrick FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
24609467b48Spatrick   return new AMDGPULowerKernelArguments();
24709467b48Spatrick }
248