//===-- AMDGPULowerKernelArguments.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces uses of kernel arguments with loads at fixed
/// offsets from the kernarg segment base pointer (byref arguments are instead
/// replaced with pointers into that segment).
//
//===----------------------------------------------------------------------===//
1309467b48Spatrick
1409467b48Spatrick #include "AMDGPU.h"
1573471bf0Spatrick #include "GCNSubtarget.h"
1609467b48Spatrick #include "llvm/CodeGen/TargetPassConfig.h"
1773471bf0Spatrick #include "llvm/IR/IntrinsicsAMDGPU.h"
1809467b48Spatrick #include "llvm/IR/IRBuilder.h"
1909467b48Spatrick #include "llvm/IR/MDBuilder.h"
2073471bf0Spatrick #include "llvm/Target/TargetMachine.h"
2109467b48Spatrick #define DEBUG_TYPE "amdgpu-lower-kernel-arguments"
2209467b48Spatrick
2309467b48Spatrick using namespace llvm;
2409467b48Spatrick
2509467b48Spatrick namespace {
2609467b48Spatrick
2709467b48Spatrick class AMDGPULowerKernelArguments : public FunctionPass{
2809467b48Spatrick public:
2909467b48Spatrick static char ID;
3009467b48Spatrick
AMDGPULowerKernelArguments()3109467b48Spatrick AMDGPULowerKernelArguments() : FunctionPass(ID) {}
3209467b48Spatrick
3309467b48Spatrick bool runOnFunction(Function &F) override;
3409467b48Spatrick
getAnalysisUsage(AnalysisUsage & AU) const3509467b48Spatrick void getAnalysisUsage(AnalysisUsage &AU) const override {
3609467b48Spatrick AU.addRequired<TargetPassConfig>();
3709467b48Spatrick AU.setPreservesAll();
3809467b48Spatrick }
3909467b48Spatrick };
4009467b48Spatrick
4109467b48Spatrick } // end anonymous namespace
4209467b48Spatrick
43097a140dSpatrick // skip allocas
getInsertPt(BasicBlock & BB)44097a140dSpatrick static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
45097a140dSpatrick BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
46097a140dSpatrick for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
47097a140dSpatrick AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);
48097a140dSpatrick
49097a140dSpatrick // If this is a dynamic alloca, the value may depend on the loaded kernargs,
50097a140dSpatrick // so loads will need to be inserted before it.
51097a140dSpatrick if (!AI || !AI->isStaticAlloca())
52097a140dSpatrick break;
53097a140dSpatrick }
54097a140dSpatrick
55097a140dSpatrick return InsPt;
56097a140dSpatrick }
57097a140dSpatrick
runOnFunction(Function & F)5809467b48Spatrick bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
5909467b48Spatrick CallingConv::ID CC = F.getCallingConv();
6009467b48Spatrick if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
6109467b48Spatrick return false;
6209467b48Spatrick
6309467b48Spatrick auto &TPC = getAnalysis<TargetPassConfig>();
6409467b48Spatrick
6509467b48Spatrick const TargetMachine &TM = TPC.getTM<TargetMachine>();
6609467b48Spatrick const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
6709467b48Spatrick LLVMContext &Ctx = F.getParent()->getContext();
6809467b48Spatrick const DataLayout &DL = F.getParent()->getDataLayout();
6909467b48Spatrick BasicBlock &EntryBlock = *F.begin();
70097a140dSpatrick IRBuilder<> Builder(&*getInsertPt(EntryBlock));
7109467b48Spatrick
7209467b48Spatrick const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
7309467b48Spatrick const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);
7409467b48Spatrick
7509467b48Spatrick Align MaxAlign;
76*d415bd75Srobert // FIXME: Alignment is broken with explicit arg offset.;
7709467b48Spatrick const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
7809467b48Spatrick if (TotalKernArgSize == 0)
7909467b48Spatrick return false;
8009467b48Spatrick
8109467b48Spatrick CallInst *KernArgSegment =
8209467b48Spatrick Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
8309467b48Spatrick nullptr, F.getName() + ".kernarg.segment");
8409467b48Spatrick
85*d415bd75Srobert KernArgSegment->addRetAttr(Attribute::NonNull);
86*d415bd75Srobert KernArgSegment->addRetAttr(
8709467b48Spatrick Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
8809467b48Spatrick
8909467b48Spatrick unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
9009467b48Spatrick uint64_t ExplicitArgOffset = 0;
9109467b48Spatrick
9209467b48Spatrick for (Argument &Arg : F.args()) {
9373471bf0Spatrick const bool IsByRef = Arg.hasByRefAttr();
9473471bf0Spatrick Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType();
95*d415bd75Srobert MaybeAlign ParamAlign = IsByRef ? Arg.getParamAlign() : std::nullopt;
96*d415bd75Srobert Align ABITypeAlign = DL.getValueOrABITypeAlignment(ParamAlign, ArgTy);
9773471bf0Spatrick
9873471bf0Spatrick uint64_t Size = DL.getTypeSizeInBits(ArgTy);
9973471bf0Spatrick uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);
10009467b48Spatrick
10109467b48Spatrick uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
10209467b48Spatrick ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;
10309467b48Spatrick
10409467b48Spatrick if (Arg.use_empty())
10509467b48Spatrick continue;
10609467b48Spatrick
10773471bf0Spatrick // If this is byval, the loads are already explicit in the function. We just
10873471bf0Spatrick // need to rewrite the pointer values.
10973471bf0Spatrick if (IsByRef) {
11073471bf0Spatrick Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64(
11173471bf0Spatrick Builder.getInt8Ty(), KernArgSegment, EltOffset,
11273471bf0Spatrick Arg.getName() + ".byval.kernarg.offset");
11373471bf0Spatrick
11473471bf0Spatrick Value *CastOffsetPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
11573471bf0Spatrick ArgOffsetPtr, Arg.getType());
11673471bf0Spatrick Arg.replaceAllUsesWith(CastOffsetPtr);
11773471bf0Spatrick continue;
11873471bf0Spatrick }
11973471bf0Spatrick
12009467b48Spatrick if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
12109467b48Spatrick // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
12209467b48Spatrick // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
12309467b48Spatrick // can't represent this with range metadata because it's only allowed for
12409467b48Spatrick // integer types.
12509467b48Spatrick if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
12609467b48Spatrick PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
12709467b48Spatrick !ST.hasUsableDSOffset())
12809467b48Spatrick continue;
12909467b48Spatrick
13009467b48Spatrick // FIXME: We can replace this with equivalent alias.scope/noalias
13109467b48Spatrick // metadata, but this appears to be a lot of work.
13209467b48Spatrick if (Arg.hasNoAliasAttr())
13309467b48Spatrick continue;
13409467b48Spatrick }
13509467b48Spatrick
136097a140dSpatrick auto *VT = dyn_cast<FixedVectorType>(ArgTy);
13709467b48Spatrick bool IsV3 = VT && VT->getNumElements() == 3;
13809467b48Spatrick bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();
13909467b48Spatrick
14009467b48Spatrick VectorType *V4Ty = nullptr;
14109467b48Spatrick
14209467b48Spatrick int64_t AlignDownOffset = alignDown(EltOffset, 4);
14309467b48Spatrick int64_t OffsetDiff = EltOffset - AlignDownOffset;
14409467b48Spatrick Align AdjustedAlign = commonAlignment(
14509467b48Spatrick KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);
14609467b48Spatrick
14709467b48Spatrick Value *ArgPtr;
14809467b48Spatrick Type *AdjustedArgTy;
14909467b48Spatrick if (DoShiftOpt) { // FIXME: Handle aggregate types
15009467b48Spatrick // Since we don't have sub-dword scalar loads, avoid doing an extload by
15109467b48Spatrick // loading earlier than the argument address, and extracting the relevant
15209467b48Spatrick // bits.
15309467b48Spatrick //
15409467b48Spatrick // Additionally widen any sub-dword load to i32 even if suitably aligned,
15509467b48Spatrick // so that CSE between different argument loads works easily.
15609467b48Spatrick ArgPtr = Builder.CreateConstInBoundsGEP1_64(
15709467b48Spatrick Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
15809467b48Spatrick Arg.getName() + ".kernarg.offset.align.down");
15909467b48Spatrick AdjustedArgTy = Builder.getInt32Ty();
16009467b48Spatrick } else {
16109467b48Spatrick ArgPtr = Builder.CreateConstInBoundsGEP1_64(
16209467b48Spatrick Builder.getInt8Ty(), KernArgSegment, EltOffset,
16309467b48Spatrick Arg.getName() + ".kernarg.offset");
16409467b48Spatrick AdjustedArgTy = ArgTy;
16509467b48Spatrick }
16609467b48Spatrick
16709467b48Spatrick if (IsV3 && Size >= 32) {
168097a140dSpatrick V4Ty = FixedVectorType::get(VT->getElementType(), 4);
16909467b48Spatrick // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
17009467b48Spatrick AdjustedArgTy = V4Ty;
17109467b48Spatrick }
17209467b48Spatrick
17309467b48Spatrick ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
17409467b48Spatrick ArgPtr->getName() + ".cast");
17509467b48Spatrick LoadInst *Load =
176097a140dSpatrick Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
17709467b48Spatrick Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
17809467b48Spatrick
17909467b48Spatrick MDBuilder MDB(Ctx);
18009467b48Spatrick
18109467b48Spatrick if (isa<PointerType>(ArgTy)) {
18209467b48Spatrick if (Arg.hasNonNullAttr())
18309467b48Spatrick Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
18409467b48Spatrick
18509467b48Spatrick uint64_t DerefBytes = Arg.getDereferenceableBytes();
18609467b48Spatrick if (DerefBytes != 0) {
18709467b48Spatrick Load->setMetadata(
18809467b48Spatrick LLVMContext::MD_dereferenceable,
18909467b48Spatrick MDNode::get(Ctx,
19009467b48Spatrick MDB.createConstant(
19109467b48Spatrick ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
19209467b48Spatrick }
19309467b48Spatrick
19409467b48Spatrick uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
19509467b48Spatrick if (DerefOrNullBytes != 0) {
19609467b48Spatrick Load->setMetadata(
19709467b48Spatrick LLVMContext::MD_dereferenceable_or_null,
19809467b48Spatrick MDNode::get(Ctx,
19909467b48Spatrick MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
20009467b48Spatrick DerefOrNullBytes))));
20109467b48Spatrick }
20209467b48Spatrick
203*d415bd75Srobert if (MaybeAlign ParamAlign = Arg.getParamAlign()) {
20409467b48Spatrick Load->setMetadata(
20509467b48Spatrick LLVMContext::MD_align,
206*d415bd75Srobert MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
207*d415bd75Srobert Builder.getInt64Ty(), ParamAlign->value()))));
20809467b48Spatrick }
20909467b48Spatrick }
21009467b48Spatrick
21109467b48Spatrick // TODO: Convert noalias arg to !noalias
21209467b48Spatrick
21309467b48Spatrick if (DoShiftOpt) {
21409467b48Spatrick Value *ExtractBits = OffsetDiff == 0 ?
21509467b48Spatrick Load : Builder.CreateLShr(Load, OffsetDiff * 8);
21609467b48Spatrick
21709467b48Spatrick IntegerType *ArgIntTy = Builder.getIntNTy(Size);
21809467b48Spatrick Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
21909467b48Spatrick Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
22009467b48Spatrick Arg.getName() + ".load");
22109467b48Spatrick Arg.replaceAllUsesWith(NewVal);
22209467b48Spatrick } else if (IsV3) {
22373471bf0Spatrick Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2},
22409467b48Spatrick Arg.getName() + ".load");
22509467b48Spatrick Arg.replaceAllUsesWith(Shuf);
22609467b48Spatrick } else {
22709467b48Spatrick Load->setName(Arg.getName() + ".load");
22809467b48Spatrick Arg.replaceAllUsesWith(Load);
22909467b48Spatrick }
23009467b48Spatrick }
23109467b48Spatrick
232*d415bd75Srobert KernArgSegment->addRetAttr(
23309467b48Spatrick Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
23409467b48Spatrick
23509467b48Spatrick return true;
23609467b48Spatrick }
23709467b48Spatrick
23809467b48Spatrick INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
23909467b48Spatrick "AMDGPU Lower Kernel Arguments", false, false)
24009467b48Spatrick INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
24109467b48Spatrick false, false)
24209467b48Spatrick
24309467b48Spatrick char AMDGPULowerKernelArguments::ID = 0;
24409467b48Spatrick
createAMDGPULowerKernelArgumentsPass()24509467b48Spatrick FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
24609467b48Spatrick return new AMDGPULowerKernelArguments();
24709467b48Spatrick }
248