xref: /llvm-project/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp (revision 5550c821897ab77e664977121a0e90ad5be1ff59)
//===- UpdateVCEPass.cpp - Deduce SPIR-V VCE requirements -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to deduce minimal version/extension/capability
10 // requirements for a spirv::ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Visitors.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include <optional>
26 
27 namespace mlir {
28 namespace spirv {
29 #define GEN_PASS_DEF_SPIRVUPDATEVCEPASS
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
31 } // namespace spirv
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 namespace {
37 /// Pass to deduce minimal version/extension/capability requirements for a
38 /// spirv::ModuleOp.
39 class UpdateVCEPass final
40     : public spirv::impl::SPIRVUpdateVCEPassBase<UpdateVCEPass> {
41   void runOnOperation() override;
42 };
43 } // namespace
44 
45 /// Checks that `candidates` extension requirements are possible to be satisfied
46 /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
47 /// errors attaching to the given `op` on failures.
48 ///
49 ///  `candidates` is a vector of vector for extension requirements following
50 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
51 /// convention.
checkAndUpdateExtensionRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates,SetVector<spirv::Extension> & deducedExtensions)52 static LogicalResult checkAndUpdateExtensionRequirements(
53     Operation *op, const spirv::TargetEnv &targetEnv,
54     const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
55     SetVector<spirv::Extension> &deducedExtensions) {
56   for (const auto &ors : candidates) {
57     if (std::optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
58       deducedExtensions.insert(*chosen);
59     } else {
60       SmallVector<StringRef, 4> extStrings;
61       for (spirv::Extension ext : ors)
62         extStrings.push_back(spirv::stringifyExtension(ext));
63 
64       return op->emitError("'")
65              << op->getName() << "' requires at least one extension in ["
66              << llvm::join(extStrings, ", ")
67              << "] but none allowed in target environment";
68     }
69   }
70   return success();
71 }
72 
73 /// Checks that `candidates`capability requirements are possible to be satisfied
74 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
75 /// errors attaching to the given `op` on failures.
76 ///
77 ///  `candidates` is a vector of vector for capability requirements following
78 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
79 /// convention.
checkAndUpdateCapabilityRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates,SetVector<spirv::Capability> & deducedCapabilities)80 static LogicalResult checkAndUpdateCapabilityRequirements(
81     Operation *op, const spirv::TargetEnv &targetEnv,
82     const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
83     SetVector<spirv::Capability> &deducedCapabilities) {
84   for (const auto &ors : candidates) {
85     if (std::optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
86       deducedCapabilities.insert(*chosen);
87     } else {
88       SmallVector<StringRef, 4> capStrings;
89       for (spirv::Capability cap : ors)
90         capStrings.push_back(spirv::stringifyCapability(cap));
91 
92       return op->emitError("'")
93              << op->getName() << "' requires at least one capability in ["
94              << llvm::join(capStrings, ", ")
95              << "] but none allowed in target environment";
96     }
97   }
98   return success();
99 }
100 
runOnOperation()101 void UpdateVCEPass::runOnOperation() {
102   spirv::ModuleOp module = getOperation();
103 
104   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
105   if (!targetAttr) {
106     module.emitError("missing 'spirv.target_env' attribute");
107     return signalPassFailure();
108   }
109 
110   spirv::TargetEnv targetEnv(targetAttr);
111   spirv::Version allowedVersion = targetAttr.getVersion();
112 
113   spirv::Version deducedVersion = spirv::Version::V_1_0;
114   SetVector<spirv::Extension> deducedExtensions;
115   SetVector<spirv::Capability> deducedCapabilities;
116 
117   // Walk each SPIR-V op to deduce the minimal version/extension/capability
118   // requirements.
119   WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
120     // Op min version requirements
121     if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
122       std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
123       if (minVersion) {
124         deducedVersion = std::max(deducedVersion, *minVersion);
125         if (deducedVersion > allowedVersion) {
126           return op->emitError("'")
127                  << op->getName() << "' requires min version "
128                  << spirv::stringifyVersion(deducedVersion)
129                  << " but target environment allows up to "
130                  << spirv::stringifyVersion(allowedVersion);
131         }
132       }
133     }
134 
135     // Op extension requirements
136     if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
137       if (failed(checkAndUpdateExtensionRequirements(
138               op, targetEnv, extensions.getExtensions(), deducedExtensions)))
139         return WalkResult::interrupt();
140 
141     // Op capability requirements
142     if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
143       if (failed(checkAndUpdateCapabilityRequirements(
144               op, targetEnv, capabilities.getCapabilities(),
145               deducedCapabilities)))
146         return WalkResult::interrupt();
147 
148     SmallVector<Type, 4> valueTypes;
149     valueTypes.append(op->operand_type_begin(), op->operand_type_end());
150     valueTypes.append(op->result_type_begin(), op->result_type_end());
151 
152     // Special treatment for global variables, whose type requirements are
153     // conveyed by type attributes.
154     if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
155       valueTypes.push_back(globalVar.getType());
156 
157     // Requirements from values' types
158     SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
159     SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
160     for (Type valueType : valueTypes) {
161       typeExtensions.clear();
162       cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
163       if (failed(checkAndUpdateExtensionRequirements(
164               op, targetEnv, typeExtensions, deducedExtensions)))
165         return WalkResult::interrupt();
166 
167       typeCapabilities.clear();
168       cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
169       if (failed(checkAndUpdateCapabilityRequirements(
170               op, targetEnv, typeCapabilities, deducedCapabilities)))
171         return WalkResult::interrupt();
172     }
173 
174     return WalkResult::advance();
175   });
176 
177   if (walkResult.wasInterrupted())
178     return signalPassFailure();
179 
180   // TODO: verify that the deduced version is consistent with
181   // SPIR-V ops' maximal version requirements.
182 
183   auto triple = spirv::VerCapExtAttr::get(
184       deducedVersion, deducedCapabilities.getArrayRef(),
185       deducedExtensions.getArrayRef(), &getContext());
186   module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
187 }
188