xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp (revision b7b337fb91f9b0538fcc4467ffca7c6c71192bc9)
1 //===- SPIRVAttachTarget.cpp - Attach an SPIR-V target --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the `GPUSPIRVAttachTarget` pass, attaching
10 // `#spirv.target_env` attributes to GPU modules.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/GPU/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Target/SPIRV/Target.h"
23 #include "llvm/Support/Regex.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
27 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::spirv;
32 
33 namespace {
34 struct SPIRVAttachTarget
35     : public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> {
36   using Base::Base;
37 
38   void runOnOperation() override;
39 
getDependentDialects__anon2d62042a0111::SPIRVAttachTarget40   void getDependentDialects(DialectRegistry &registry) const override {
41     registry.insert<spirv::SPIRVDialect>();
42   }
43 };
44 } // namespace
45 
runOnOperation()46 void SPIRVAttachTarget::runOnOperation() {
47   OpBuilder builder(&getContext());
48   auto versionSymbol = symbolizeVersion(spirvVersion);
49   if (!versionSymbol)
50     return signalPassFailure();
51   auto apiSymbol = symbolizeClientAPI(clientApi);
52   if (!apiSymbol)
53     return signalPassFailure();
54   auto vendorSymbol = symbolizeVendor(deviceVendor);
55   if (!vendorSymbol)
56     return signalPassFailure();
57   auto deviceTypeSymbol = symbolizeDeviceType(deviceType);
58   if (!deviceTypeSymbol)
59     return signalPassFailure();
60   // Set the default device ID if none was given
61   if (!deviceId.hasValue())
62     deviceId = mlir::spirv::TargetEnvAttr::kUnknownDeviceID;
63 
64   Version version = versionSymbol.value();
65   SmallVector<Capability, 4> capabilities;
66   SmallVector<Extension, 8> extensions;
67   for (const auto &cap : spirvCapabilities) {
68     auto capSymbol = symbolizeCapability(cap);
69     if (capSymbol)
70       capabilities.push_back(capSymbol.value());
71   }
72   ArrayRef<Capability> caps(capabilities);
73   for (const auto &ext : spirvExtensions) {
74     auto extSymbol = symbolizeExtension(ext);
75     if (extSymbol)
76       extensions.push_back(extSymbol.value());
77   }
78   ArrayRef<Extension> exts(extensions);
79   VerCapExtAttr vce = VerCapExtAttr::get(version, caps, exts, &getContext());
80   auto target = TargetEnvAttr::get(vce, getDefaultResourceLimits(&getContext()),
81                                    apiSymbol.value(), vendorSymbol.value(),
82                                    deviceTypeSymbol.value(), deviceId);
83   llvm::Regex matcher(moduleMatcher);
84   getOperation()->walk([&](gpu::GPUModuleOp gpuModule) {
85     // Check if the name of the module matches.
86     if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName()))
87       return;
88     // Create the target array.
89     SmallVector<Attribute> targets;
90     if (std::optional<ArrayAttr> attrs = gpuModule.getTargets())
91       targets.append(attrs->getValue().begin(), attrs->getValue().end());
92     targets.push_back(target);
93     // Remove any duplicate targets.
94     targets.erase(llvm::unique(targets), targets.end());
95     // Update the target attribute array.
96     gpuModule.setTargetsAttr(builder.getArrayAttr(targets));
97   });
98 }
99