1 //===- NVVMAttachTarget.cpp - Attach an NVVM target -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the `GpuNVVMAttachTarget` pass, attaching `#nvvm.target` 10 // attributes to GPU modules. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/GPU/Transforms/Passes.h" 15 16 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Target/LLVM/NVVM/Target.h" 21 #include "llvm/Support/Regex.h" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_GPUNVVMATTACHTARGET 25 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 using namespace mlir::NVVM; 30 31 namespace { 32 struct NVVMAttachTarget 33 : public impl::GpuNVVMAttachTargetBase<NVVMAttachTarget> { 34 using Base::Base; 35 36 DictionaryAttr getFlags(OpBuilder &builder) const; 37 38 void runOnOperation() override; 39 40 void getDependentDialects(DialectRegistry ®istry) const override { 41 registry.insert<NVVM::NVVMDialect>(); 42 } 43 }; 44 } // namespace 45 46 DictionaryAttr NVVMAttachTarget::getFlags(OpBuilder &builder) const { 47 UnitAttr unitAttr = builder.getUnitAttr(); 48 SmallVector<NamedAttribute, 2> flags; 49 auto addFlag = [&](StringRef flag) { 50 flags.push_back(builder.getNamedAttr(flag, unitAttr)); 51 }; 52 if (fastFlag) 53 addFlag("fast"); 54 if (ftzFlag) 55 addFlag("ftz"); 56 if (!flags.empty()) 57 return builder.getDictionaryAttr(flags); 58 return nullptr; 59 } 60 61 void NVVMAttachTarget::runOnOperation() { 62 OpBuilder builder(&getContext()); 63 ArrayRef<std::string> libs(linkLibs); 64 SmallVector<StringRef> filesToLink(libs); 65 auto target = builder.getAttr<NVVMTargetAttr>( 66 optLevel, triple, chip, features, getFlags(builder), 67 filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink)); 68 llvm::Regex matcher(moduleMatcher); 69 for (Region ®ion : getOperation()->getRegions()) 70 for (Block &block : region.getBlocks()) 71 for (auto module : block.getOps<gpu::GPUModuleOp>()) { 72 // Check if the name of the module matches. 73 if (!moduleMatcher.empty() && !matcher.match(module.getName())) 74 continue; 75 // Create the target array. 76 SmallVector<Attribute> targets; 77 if (std::optional<ArrayAttr> attrs = module.getTargets()) 78 targets.append(attrs->getValue().begin(), attrs->getValue().end()); 79 targets.push_back(target); 80 // Remove any duplicate targets. 81 targets.erase(llvm::unique(targets), targets.end()); 82 // Update the target attribute array. 83 module.setTargetsAttr(builder.getArrayAttr(targets)); 84 } 85 } 86