1 //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file This pass does attempts to make use of reqd_work_group_size metadata 10 /// to eliminate loads from the dispatch packet and to constant fold OpenCL 11 /// get_local_size-like functions. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "AMDGPU.h" 16 #include "AMDGPUTargetMachine.h" 17 #include "llvm/Analysis/ValueTracking.h" 18 #include "llvm/CodeGen/Passes.h" 19 #include "llvm/CodeGen/TargetPassConfig.h" 20 #include "llvm/IR/Constants.h" 21 #include "llvm/IR/Function.h" 22 #include "llvm/IR/InstIterator.h" 23 #include "llvm/IR/Instructions.h" 24 #include "llvm/IR/PassManager.h" 25 #include "llvm/IR/PatternMatch.h" 26 #include "llvm/Pass.h" 27 28 #define DEBUG_TYPE "amdgpu-lower-kernel-attributes" 29 30 using namespace llvm; 31 32 namespace { 33 34 // Field offsets in hsa_kernel_dispatch_packet_t. 35 enum DispatchPackedOffsets { 36 WORKGROUP_SIZE_X = 4, 37 WORKGROUP_SIZE_Y = 6, 38 WORKGROUP_SIZE_Z = 8, 39 40 GRID_SIZE_X = 12, 41 GRID_SIZE_Y = 16, 42 GRID_SIZE_Z = 20 43 }; 44 45 class AMDGPULowerKernelAttributes : public ModulePass { 46 public: 47 static char ID; 48 49 AMDGPULowerKernelAttributes() : ModulePass(ID) {} 50 51 bool runOnModule(Module &M) override; 52 53 StringRef getPassName() const override { 54 return "AMDGPU Kernel Attributes"; 55 } 56 57 void getAnalysisUsage(AnalysisUsage &AU) const override { 58 AU.setPreservesAll(); 59 } 60 }; 61 62 } // end anonymous namespace 63 64 static bool processUse(CallInst *CI) { 65 Function *F = CI->getParent()->getParent(); 66 67 auto MD = F->getMetadata("reqd_work_group_size"); 68 const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3; 69 70 const bool HasUniformWorkGroupSize = 71 F->getFnAttribute("uniform-work-group-size").getValueAsString() == "true"; 72 73 if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize) 74 return false; 75 76 Value *WorkGroupSizeX = nullptr; 77 Value *WorkGroupSizeY = nullptr; 78 Value *WorkGroupSizeZ = nullptr; 79 80 Value *GridSizeX = nullptr; 81 Value *GridSizeY = nullptr; 82 Value *GridSizeZ = nullptr; 83 84 const DataLayout &DL = F->getParent()->getDataLayout(); 85 86 // We expect to see several GEP users, casted to the appropriate type and 87 // loaded. 88 for (User *U : CI->users()) { 89 if (!U->hasOneUse()) 90 continue; 91 92 int64_t Offset = 0; 93 if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI) 94 continue; 95 96 auto *BCI = dyn_cast<BitCastInst>(*U->user_begin()); 97 if (!BCI || !BCI->hasOneUse()) 98 continue; 99 100 auto *Load = dyn_cast<LoadInst>(*BCI->user_begin()); 101 if (!Load || !Load->isSimple()) 102 continue; 103 104 unsigned LoadSize = DL.getTypeStoreSize(Load->getType()); 105 106 // TODO: Handle merged loads. 107 switch (Offset) { 108 case WORKGROUP_SIZE_X: 109 if (LoadSize == 2) 110 WorkGroupSizeX = Load; 111 break; 112 case WORKGROUP_SIZE_Y: 113 if (LoadSize == 2) 114 WorkGroupSizeY = Load; 115 break; 116 case WORKGROUP_SIZE_Z: 117 if (LoadSize == 2) 118 WorkGroupSizeZ = Load; 119 break; 120 case GRID_SIZE_X: 121 if (LoadSize == 4) 122 GridSizeX = Load; 123 break; 124 case GRID_SIZE_Y: 125 if (LoadSize == 4) 126 GridSizeY = Load; 127 break; 128 case GRID_SIZE_Z: 129 if (LoadSize == 4) 130 GridSizeZ = Load; 131 break; 132 default: 133 break; 134 } 135 } 136 137 // Pattern match the code used to handle partial workgroup dispatches in the 138 // library implementation of get_local_size, so the entire function can be 139 // constant folded with a known group size. 140 // 141 // uint r = grid_size - group_id * group_size; 142 // get_local_size = (r < group_size) ? r : group_size; 143 // 144 // If we have uniform-work-group-size (which is the default in OpenCL 1.2), 145 // the grid_size is required to be a multiple of group_size). In this case: 146 // 147 // grid_size - (group_id * group_size) < group_size 148 // -> 149 // grid_size < group_size + (group_id * group_size) 150 // 151 // (grid_size / group_size) < 1 + group_id 152 // 153 // grid_size / group_size is at least 1, so we can conclude the select 154 // condition is false (except for group_id == 0, where the select result is 155 // the same). 156 157 bool MadeChange = false; 158 Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ }; 159 Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ }; 160 161 for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) { 162 Value *GroupSize = WorkGroupSizes[I]; 163 Value *GridSize = GridSizes[I]; 164 if (!GroupSize || !GridSize) 165 continue; 166 167 for (User *U : GroupSize->users()) { 168 auto *ZextGroupSize = dyn_cast<ZExtInst>(U); 169 if (!ZextGroupSize) 170 continue; 171 172 for (User *ZextUser : ZextGroupSize->users()) { 173 auto *SI = dyn_cast<SelectInst>(ZextUser); 174 if (!SI) 175 continue; 176 177 using namespace llvm::PatternMatch; 178 auto GroupIDIntrin = I == 0 ? 179 m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() : 180 (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() : 181 m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>()); 182 183 auto SubExpr = m_Sub(m_Specific(GridSize), 184 m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))); 185 186 ICmpInst::Predicate Pred; 187 if (match(SI, 188 m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)), 189 SubExpr, 190 m_Specific(ZextGroupSize))) && 191 Pred == ICmpInst::ICMP_ULT) { 192 if (HasReqdWorkGroupSize) { 193 ConstantInt *KnownSize 194 = mdconst::extract<ConstantInt>(MD->getOperand(I)); 195 SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize, 196 SI->getType(), 197 false)); 198 } else { 199 SI->replaceAllUsesWith(ZextGroupSize); 200 } 201 202 MadeChange = true; 203 } 204 } 205 } 206 } 207 208 if (!HasReqdWorkGroupSize) 209 return MadeChange; 210 211 // Eliminate any other loads we can from the dispatch packet. 212 for (int I = 0; I < 3; ++I) { 213 Value *GroupSize = WorkGroupSizes[I]; 214 if (!GroupSize) 215 continue; 216 217 ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I)); 218 GroupSize->replaceAllUsesWith( 219 ConstantExpr::getIntegerCast(KnownSize, 220 GroupSize->getType(), 221 false)); 222 MadeChange = true; 223 } 224 225 return MadeChange; 226 } 227 228 // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get 229 // TargetPassConfig for subtarget. 230 bool AMDGPULowerKernelAttributes::runOnModule(Module &M) { 231 StringRef DispatchPtrName 232 = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr); 233 234 Function *DispatchPtr = M.getFunction(DispatchPtrName); 235 if (!DispatchPtr) // Dispatch ptr not used. 236 return false; 237 238 bool MadeChange = false; 239 240 SmallPtrSet<Instruction *, 4> HandledUses; 241 for (auto *U : DispatchPtr->users()) { 242 CallInst *CI = cast<CallInst>(U); 243 if (HandledUses.insert(CI).second) { 244 if (processUse(CI)) 245 MadeChange = true; 246 } 247 } 248 249 return MadeChange; 250 } 251 252 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE, 253 "AMDGPU IR optimizations", false, false) 254 INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU IR optimizations", 255 false, false) 256 257 char AMDGPULowerKernelAttributes::ID = 0; 258 259 ModulePass *llvm::createAMDGPULowerKernelAttributesPass() { 260 return new AMDGPULowerKernelAttributes(); 261 } 262 263 PreservedAnalyses 264 AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) { 265 StringRef DispatchPtrName = 266 Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr); 267 268 Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName); 269 if (!DispatchPtr) // Dispatch ptr not used. 270 return PreservedAnalyses::all(); 271 272 for (Instruction &I : instructions(F)) { 273 if (CallInst *CI = dyn_cast<CallInst>(&I)) { 274 if (CI->getCalledFunction() == DispatchPtr) 275 processUse(CI); 276 } 277 } 278 279 return PreservedAnalyses::all(); 280 } 281