//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass resolves calls to OpenCL image attribute, image resource ID and
/// sampler resource ID getter functions.
///
/// Image attributes (size and format) are expected to be passed to the kernel
/// as kernel arguments immediately following the image argument itself;
/// therefore this pass adds image size and format arguments to the kernel
/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
/// Note: this pass may invalidate pointers to functions.
///
/// Resource IDs of read-only images, write-only images and samplers are
/// defined to be their index among the kernel arguments of the same
/// type and access qualifier.
24 // 25 //===----------------------------------------------------------------------===// 26 27 #include "R600.h" 28 #include "llvm/ADT/SmallVector.h" 29 #include "llvm/ADT/StringRef.h" 30 #include "llvm/IR/Constants.h" 31 #include "llvm/IR/Function.h" 32 #include "llvm/IR/Instructions.h" 33 #include "llvm/IR/Metadata.h" 34 #include "llvm/IR/Module.h" 35 #include "llvm/Pass.h" 36 #include "llvm/Transforms/Utils/Cloning.h" 37 38 using namespace llvm; 39 40 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size"; 41 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format"; 42 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id"; 43 static StringRef GetSamplerResourceIDFunc = 44 "llvm.OpenCL.sampler.get.resource.id"; 45 46 static StringRef ImageSizeArgMDType = "__llvm_image_size"; 47 static StringRef ImageFormatArgMDType = "__llvm_image_format"; 48 49 static StringRef KernelsMDNodeName = "opencl.kernels"; 50 static StringRef KernelArgMDNodeNames[] = { 51 "kernel_arg_addr_space", 52 "kernel_arg_access_qual", 53 "kernel_arg_type", 54 "kernel_arg_base_type", 55 "kernel_arg_type_qual"}; 56 static const unsigned NumKernelArgMDNodes = 5; 57 58 namespace { 59 60 using MDVector = SmallVector<Metadata *, 8>; 61 struct KernelArgMD { 62 MDVector ArgVector[NumKernelArgMDNodes]; 63 }; 64 65 } // end anonymous namespace 66 67 static inline bool 68 IsImageType(StringRef TypeString) { 69 return TypeString == "image2d_t" || TypeString == "image3d_t"; 70 } 71 72 static inline bool 73 IsSamplerType(StringRef TypeString) { 74 return TypeString == "sampler_t"; 75 } 76 77 static Function * 78 GetFunctionFromMDNode(MDNode *Node) { 79 if (!Node) 80 return nullptr; 81 82 size_t NumOps = Node->getNumOperands(); 83 if (NumOps != NumKernelArgMDNodes + 1) 84 return nullptr; 85 86 auto *F = mdconst::dyn_extract<Function>(Node->getOperand(0)); 87 if (!F) 88 return nullptr; 89 90 // Validation checks. 
91 size_t ExpectNumArgNodeOps = F->arg_size() + 1; 92 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) { 93 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1)); 94 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps) 95 return nullptr; 96 if (!ArgNode->getOperand(0)) 97 return nullptr; 98 99 // FIXME: It should be possible to do image lowering when some metadata 100 // args missing or not in the expected order. 101 MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0)); 102 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i]) 103 return nullptr; 104 } 105 106 return F; 107 } 108 109 static StringRef 110 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 111 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2)); 112 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString(); 113 } 114 115 static StringRef 116 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 117 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3)); 118 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString(); 119 } 120 121 static MDVector 122 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) { 123 MDVector Res; 124 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 125 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1)); 126 Res.push_back(Node->getOperand(OpIdx)); 127 } 128 return Res; 129 } 130 131 static void 132 PushArgMD(KernelArgMD &MD, const MDVector &V) { 133 assert(V.size() == NumKernelArgMDNodes); 134 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 135 MD.ArgVector[i].push_back(V[i]); 136 } 137 } 138 139 namespace { 140 141 class R600OpenCLImageTypeLoweringPass : public ModulePass { 142 static char ID; 143 144 LLVMContext *Context; 145 Type *Int32Type; 146 Type *ImageSizeType; 147 Type *ImageFormatType; 148 SmallVector<Instruction *, 4> InstsToErase; 149 150 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID, 151 Argument &ImageSizeArg, 152 Argument &ImageFormatArg) 
{ 153 bool Modified = false; 154 155 for (auto &Use : ImageArg.uses()) { 156 auto *Inst = dyn_cast<CallInst>(Use.getUser()); 157 if (!Inst) { 158 continue; 159 } 160 161 Function *F = Inst->getCalledFunction(); 162 if (!F) 163 continue; 164 165 Value *Replacement = nullptr; 166 StringRef Name = F->getName(); 167 if (Name.starts_with(GetImageResourceIDFunc)) { 168 Replacement = ConstantInt::get(Int32Type, ResourceID); 169 } else if (Name.starts_with(GetImageSizeFunc)) { 170 Replacement = &ImageSizeArg; 171 } else if (Name.starts_with(GetImageFormatFunc)) { 172 Replacement = &ImageFormatArg; 173 } else { 174 continue; 175 } 176 177 Inst->replaceAllUsesWith(Replacement); 178 InstsToErase.push_back(Inst); 179 Modified = true; 180 } 181 182 return Modified; 183 } 184 185 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) { 186 bool Modified = false; 187 188 for (const auto &Use : SamplerArg.uses()) { 189 auto *Inst = dyn_cast<CallInst>(Use.getUser()); 190 if (!Inst) { 191 continue; 192 } 193 194 Function *F = Inst->getCalledFunction(); 195 if (!F) 196 continue; 197 198 Value *Replacement = nullptr; 199 StringRef Name = F->getName(); 200 if (Name == GetSamplerResourceIDFunc) { 201 Replacement = ConstantInt::get(Int32Type, ResourceID); 202 } else { 203 continue; 204 } 205 206 Inst->replaceAllUsesWith(Replacement); 207 InstsToErase.push_back(Inst); 208 Modified = true; 209 } 210 211 return Modified; 212 } 213 214 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) { 215 uint32_t NumReadOnlyImageArgs = 0; 216 uint32_t NumWriteOnlyImageArgs = 0; 217 uint32_t NumSamplerArgs = 0; 218 219 bool Modified = false; 220 InstsToErase.clear(); 221 for (auto *ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) { 222 Argument &Arg = *ArgI; 223 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo()); 224 225 // Handle image types. 
226 if (IsImageType(Type)) { 227 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo()); 228 uint32_t ResourceID; 229 if (AccessQual == "read_only") { 230 ResourceID = NumReadOnlyImageArgs++; 231 } else if (AccessQual == "write_only") { 232 ResourceID = NumWriteOnlyImageArgs++; 233 } else { 234 llvm_unreachable("Wrong image access qualifier."); 235 } 236 237 Argument &SizeArg = *(++ArgI); 238 Argument &FormatArg = *(++ArgI); 239 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg); 240 241 // Handle sampler type. 242 } else if (IsSamplerType(Type)) { 243 uint32_t ResourceID = NumSamplerArgs++; 244 Modified |= replaceSamplerUses(Arg, ResourceID); 245 } 246 } 247 for (auto *Inst : InstsToErase) 248 Inst->eraseFromParent(); 249 250 return Modified; 251 } 252 253 std::tuple<Function *, MDNode *> 254 addImplicitArgs(Function *F, MDNode *KernelMDNode) { 255 bool Modified = false; 256 257 FunctionType *FT = F->getFunctionType(); 258 SmallVector<Type *, 8> ArgTypes; 259 260 // Metadata operands for new MDNode. 261 KernelArgMD NewArgMDs; 262 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0)); 263 264 // Add implicit arguments to the signature. 265 for (unsigned i = 0; i < FT->getNumParams(); ++i) { 266 ArgTypes.push_back(FT->getParamType(i)); 267 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1); 268 PushArgMD(NewArgMDs, ArgMD); 269 270 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i))) 271 continue; 272 273 // Add size implicit argument. 274 ArgTypes.push_back(ImageSizeType); 275 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType); 276 PushArgMD(NewArgMDs, ArgMD); 277 278 // Add format implicit argument. 279 ArgTypes.push_back(ImageFormatType); 280 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType); 281 PushArgMD(NewArgMDs, ArgMD); 282 283 Modified = true; 284 } 285 if (!Modified) { 286 return std::tuple(nullptr, nullptr); 287 } 288 289 // Create function with new signature and clone the old body into it. 
290 auto *NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false); 291 auto *NewF = Function::Create(NewFT, F->getLinkage(), F->getName()); 292 ValueToValueMapTy VMap; 293 auto *NewFArgIt = NewF->arg_begin(); 294 for (auto &Arg: F->args()) { 295 auto ArgName = Arg.getName(); 296 NewFArgIt->setName(ArgName); 297 VMap[&Arg] = &(*NewFArgIt++); 298 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) { 299 (NewFArgIt++)->setName(Twine("__size_") + ArgName); 300 (NewFArgIt++)->setName(Twine("__format_") + ArgName); 301 } 302 } 303 SmallVector<ReturnInst*, 8> Returns; 304 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, 305 Returns); 306 307 // Build new MDNode. 308 SmallVector<Metadata *, 6> KernelMDArgs; 309 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF)); 310 for (const MDVector &MDV : NewArgMDs.ArgVector) 311 KernelMDArgs.push_back(MDNode::get(*Context, MDV)); 312 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs); 313 314 return std::tuple(NewF, NewMDNode); 315 } 316 317 bool transformKernels(Module &M) { 318 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName); 319 if (!KernelsMDNode) 320 return false; 321 322 bool Modified = false; 323 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) { 324 MDNode *KernelMDNode = KernelsMDNode->getOperand(i); 325 Function *F = GetFunctionFromMDNode(KernelMDNode); 326 if (!F) 327 continue; 328 329 Function *NewF; 330 MDNode *NewMDNode; 331 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode); 332 if (NewF) { 333 // Replace old function and metadata with new ones. 
334 F->eraseFromParent(); 335 M.getFunctionList().push_back(NewF); 336 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(), 337 NewF->getAttributes()); 338 KernelsMDNode->setOperand(i, NewMDNode); 339 340 F = NewF; 341 KernelMDNode = NewMDNode; 342 Modified = true; 343 } 344 345 Modified |= replaceImageAndSamplerUses(F, KernelMDNode); 346 } 347 348 return Modified; 349 } 350 351 public: 352 R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {} 353 354 bool runOnModule(Module &M) override { 355 Context = &M.getContext(); 356 Int32Type = Type::getInt32Ty(M.getContext()); 357 ImageSizeType = ArrayType::get(Int32Type, 3); 358 ImageFormatType = ArrayType::get(Int32Type, 2); 359 360 return transformKernels(M); 361 } 362 363 StringRef getPassName() const override { 364 return "R600 OpenCL Image Type Pass"; 365 } 366 }; 367 368 } // end anonymous namespace 369 370 char R600OpenCLImageTypeLoweringPass::ID = 0; 371 372 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() { 373 return new R600OpenCLImageTypeLoweringPass(); 374 } 375