xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/ROCDLAttachTarget.cpp (revision 5262865aac683b72f3e66de7a122e0c455ab6b9b)
1 //===- ROCDLAttachTarget.cpp - Attach an ROCDL target ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the `GpuROCDLAttachTarget` pass, attaching
10 // `#rocdl.target` attributes to GPU modules.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/GPU/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Target/LLVM/ROCDL/Target.h"
21 #include "llvm/Support/Regex.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_GPUROCDLATTACHTARGET
25 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::ROCDL;
30 
31 namespace {
32 struct ROCDLAttachTarget
33     : public impl::GpuROCDLAttachTargetBase<ROCDLAttachTarget> {
34   using Base::Base;
35 
36   DictionaryAttr getFlags(OpBuilder &builder) const;
37 
38   void runOnOperation() override;
39 
40   void getDependentDialects(DialectRegistry &registry) const override {
41     registry.insert<ROCDL::ROCDLDialect>();
42   }
43 };
44 } // namespace
45 
46 DictionaryAttr ROCDLAttachTarget::getFlags(OpBuilder &builder) const {
47   UnitAttr unitAttr = builder.getUnitAttr();
48   SmallVector<NamedAttribute, 6> flags;
49   auto addFlag = [&](StringRef flag) {
50     flags.push_back(builder.getNamedAttr(flag, unitAttr));
51   };
52   if (!wave64Flag)
53     addFlag("no_wave64");
54   if (fastFlag)
55     addFlag("fast");
56   if (dazFlag)
57     addFlag("daz");
58   if (finiteOnlyFlag)
59     addFlag("finite_only");
60   if (unsafeMathFlag)
61     addFlag("unsafe_math");
62   if (!correctSqrtFlag)
63     addFlag("unsafe_sqrt");
64   if (!flags.empty())
65     return builder.getDictionaryAttr(flags);
66   return nullptr;
67 }
68 
69 void ROCDLAttachTarget::runOnOperation() {
70   OpBuilder builder(&getContext());
71   ArrayRef<std::string> libs(linkLibs);
72   SmallVector<StringRef> filesToLink(libs);
73   auto target = builder.getAttr<ROCDLTargetAttr>(
74       optLevel, triple, chip, features, abiVersion, getFlags(builder),
75       filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink));
76   llvm::Regex matcher(moduleMatcher);
77   for (Region &region : getOperation()->getRegions())
78     for (Block &block : region.getBlocks())
79       for (auto module : block.getOps<gpu::GPUModuleOp>()) {
80         // Check if the name of the module matches.
81         if (!moduleMatcher.empty() && !matcher.match(module.getName()))
82           continue;
83         // Create the target array.
84         SmallVector<Attribute> targets;
85         if (std::optional<ArrayAttr> attrs = module.getTargets())
86           targets.append(attrs->getValue().begin(), attrs->getValue().end());
87         targets.push_back(target);
88         // Remove any duplicate targets.
89         targets.erase(llvm::unique(targets), targets.end());
90         // Update the target attribute array.
91         module.setTargetsAttr(builder.getArrayAttr(targets));
92       }
93 }
94