1 //===- ROCDLAttachTarget.cpp - Attach an ROCDL target ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the `GpuROCDLAttachTarget` pass, attaching 10 // `#rocdl.target` attributes to GPU modules. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/GPU/Transforms/Passes.h" 15 16 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 17 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Target/LLVM/ROCDL/Target.h" 21 #include "llvm/Support/Regex.h" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_GPUROCDLATTACHTARGET 25 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 using namespace mlir::ROCDL; 30 31 namespace { 32 struct ROCDLAttachTarget 33 : public impl::GpuROCDLAttachTargetBase<ROCDLAttachTarget> { 34 using Base::Base; 35 36 DictionaryAttr getFlags(OpBuilder &builder) const; 37 38 void runOnOperation() override; 39 40 void getDependentDialects(DialectRegistry ®istry) const override { 41 registry.insert<ROCDL::ROCDLDialect>(); 42 } 43 }; 44 } // namespace 45 46 DictionaryAttr ROCDLAttachTarget::getFlags(OpBuilder &builder) const { 47 UnitAttr unitAttr = builder.getUnitAttr(); 48 SmallVector<NamedAttribute, 6> flags; 49 auto addFlag = [&](StringRef flag) { 50 flags.push_back(builder.getNamedAttr(flag, unitAttr)); 51 }; 52 if (!wave64Flag) 53 addFlag("no_wave64"); 54 if (fastFlag) 55 addFlag("fast"); 56 if (dazFlag) 57 addFlag("daz"); 58 if (finiteOnlyFlag) 59 addFlag("finite_only"); 60 if (unsafeMathFlag) 61 addFlag("unsafe_math"); 62 if (!correctSqrtFlag) 63 addFlag("unsafe_sqrt"); 64 if (!flags.empty()) 65 return builder.getDictionaryAttr(flags); 66 return nullptr; 67 } 68 69 void ROCDLAttachTarget::runOnOperation() { 70 OpBuilder builder(&getContext()); 71 ArrayRef<std::string> libs(linkLibs); 72 SmallVector<StringRef> filesToLink(libs); 73 auto target = builder.getAttr<ROCDLTargetAttr>( 74 optLevel, triple, chip, features, abiVersion, getFlags(builder), 75 filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink)); 76 llvm::Regex matcher(moduleMatcher); 77 for (Region ®ion : getOperation()->getRegions()) 78 for (Block &block : region.getBlocks()) 79 for (auto module : block.getOps<gpu::GPUModuleOp>()) { 80 // Check if the name of the module matches. 81 if (!moduleMatcher.empty() && !matcher.match(module.getName())) 82 continue; 83 // Create the target array. 84 SmallVector<Attribute> targets; 85 if (std::optional<ArrayAttr> attrs = module.getTargets()) 86 targets.append(attrs->getValue().begin(), attrs->getValue().end()); 87 targets.push_back(target); 88 // Remove any duplicate targets. 89 targets.erase(llvm::unique(targets), targets.end()); 90 // Update the target attribute array. 91 module.setTargetsAttr(builder.getArrayAttr(targets)); 92 } 93 } 94