//===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to deduce minimal version/extension/capability
// requirements for a spirv::ModuleOp.
//
//===----------------------------------------------------------------------===//
139414db10SLei Zhang
1467d0d7acSMichele Scuttari #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
1567d0d7acSMichele Scuttari
1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1701178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1801178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1901178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
209414db10SLei Zhang #include "mlir/IR/Builders.h"
219414db10SLei Zhang #include "mlir/IR/Visitors.h"
229414db10SLei Zhang #include "llvm/ADT/SetVector.h"
239414db10SLei Zhang #include "llvm/ADT/SmallSet.h"
24297a5b7cSNico Weber #include "llvm/ADT/StringExtras.h"
25a1fe1f5fSKazu Hirata #include <optional>
269414db10SLei Zhang
2767d0d7acSMichele Scuttari namespace mlir {
2867d0d7acSMichele Scuttari namespace spirv {
29a35a8f4bSJakub Kuderski #define GEN_PASS_DEF_SPIRVUPDATEVCEPASS
3067d0d7acSMichele Scuttari #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
3167d0d7acSMichele Scuttari } // namespace spirv
3267d0d7acSMichele Scuttari } // namespace mlir
3367d0d7acSMichele Scuttari
349414db10SLei Zhang using namespace mlir;
359414db10SLei Zhang
369414db10SLei Zhang namespace {
379414db10SLei Zhang /// Pass to deduce minimal version/extension/capability requirements for a
389414db10SLei Zhang /// spirv::ModuleOp.
3967d0d7acSMichele Scuttari class UpdateVCEPass final
40a35a8f4bSJakub Kuderski : public spirv::impl::SPIRVUpdateVCEPassBase<UpdateVCEPass> {
419414db10SLei Zhang void runOnOperation() override;
429414db10SLei Zhang };
439414db10SLei Zhang } // namespace
449414db10SLei Zhang
45e5c85a5aSLei Zhang /// Checks that `candidates` extension requirements are possible to be satisfied
4658df5e6dSLei Zhang /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
4758df5e6dSLei Zhang /// errors attaching to the given `op` on failures.
48e5c85a5aSLei Zhang ///
49e5c85a5aSLei Zhang /// `candidates` is a vector of vector for extension requirements following
50e5c85a5aSLei Zhang /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
51e5c85a5aSLei Zhang /// convention.
checkAndUpdateExtensionRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates,SetVector<spirv::Extension> & deducedExtensions)52e5c85a5aSLei Zhang static LogicalResult checkAndUpdateExtensionRequirements(
5358df5e6dSLei Zhang Operation *op, const spirv::TargetEnv &targetEnv,
54e5c85a5aSLei Zhang const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
554efb7754SRiver Riddle SetVector<spirv::Extension> &deducedExtensions) {
56e5c85a5aSLei Zhang for (const auto &ors : candidates) {
570a81ace0SKazu Hirata if (std::optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
58e5c85a5aSLei Zhang deducedExtensions.insert(*chosen);
59e5c85a5aSLei Zhang } else {
60e5c85a5aSLei Zhang SmallVector<StringRef, 4> extStrings;
61e5c85a5aSLei Zhang for (spirv::Extension ext : ors)
62e5c85a5aSLei Zhang extStrings.push_back(spirv::stringifyExtension(ext));
63e5c85a5aSLei Zhang
64e5c85a5aSLei Zhang return op->emitError("'")
65e5c85a5aSLei Zhang << op->getName() << "' requires at least one extension in ["
66e5c85a5aSLei Zhang << llvm::join(extStrings, ", ")
67e5c85a5aSLei Zhang << "] but none allowed in target environment";
68e5c85a5aSLei Zhang }
69e5c85a5aSLei Zhang }
70e5c85a5aSLei Zhang return success();
71e5c85a5aSLei Zhang }
72e5c85a5aSLei Zhang
73e5c85a5aSLei Zhang /// Checks that `candidates`capability requirements are possible to be satisfied
7458df5e6dSLei Zhang /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
7558df5e6dSLei Zhang /// errors attaching to the given `op` on failures.
76e5c85a5aSLei Zhang ///
77e5c85a5aSLei Zhang /// `candidates` is a vector of vector for capability requirements following
78e5c85a5aSLei Zhang /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
79e5c85a5aSLei Zhang /// convention.
checkAndUpdateCapabilityRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates,SetVector<spirv::Capability> & deducedCapabilities)80e5c85a5aSLei Zhang static LogicalResult checkAndUpdateCapabilityRequirements(
8158df5e6dSLei Zhang Operation *op, const spirv::TargetEnv &targetEnv,
82e5c85a5aSLei Zhang const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
834efb7754SRiver Riddle SetVector<spirv::Capability> &deducedCapabilities) {
84e5c85a5aSLei Zhang for (const auto &ors : candidates) {
850a81ace0SKazu Hirata if (std::optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
86e5c85a5aSLei Zhang deducedCapabilities.insert(*chosen);
87e5c85a5aSLei Zhang } else {
88e5c85a5aSLei Zhang SmallVector<StringRef, 4> capStrings;
89e5c85a5aSLei Zhang for (spirv::Capability cap : ors)
90e5c85a5aSLei Zhang capStrings.push_back(spirv::stringifyCapability(cap));
91e5c85a5aSLei Zhang
92e5c85a5aSLei Zhang return op->emitError("'")
93e5c85a5aSLei Zhang << op->getName() << "' requires at least one capability in ["
94e5c85a5aSLei Zhang << llvm::join(capStrings, ", ")
95e5c85a5aSLei Zhang << "] but none allowed in target environment";
96e5c85a5aSLei Zhang }
97e5c85a5aSLei Zhang }
98e5c85a5aSLei Zhang return success();
99e5c85a5aSLei Zhang }
100e5c85a5aSLei Zhang
runOnOperation()101039b969bSMichele Scuttari void UpdateVCEPass::runOnOperation() {
1029414db10SLei Zhang spirv::ModuleOp module = getOperation();
1039414db10SLei Zhang
10458df5e6dSLei Zhang spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
10558df5e6dSLei Zhang if (!targetAttr) {
1065ab6ef75SJakub Kuderski module.emitError("missing 'spirv.target_env' attribute");
1079414db10SLei Zhang return signalPassFailure();
1089414db10SLei Zhang }
1099414db10SLei Zhang
11058df5e6dSLei Zhang spirv::TargetEnv targetEnv(targetAttr);
11158df5e6dSLei Zhang spirv::Version allowedVersion = targetAttr.getVersion();
1129414db10SLei Zhang
1139414db10SLei Zhang spirv::Version deducedVersion = spirv::Version::V_1_0;
1144efb7754SRiver Riddle SetVector<spirv::Extension> deducedExtensions;
1154efb7754SRiver Riddle SetVector<spirv::Capability> deducedCapabilities;
1169414db10SLei Zhang
1179414db10SLei Zhang // Walk each SPIR-V op to deduce the minimal version/extension/capability
1189414db10SLei Zhang // requirements.
1199414db10SLei Zhang WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
120e5c85a5aSLei Zhang // Op min version requirements
121cb395f66SLei Zhang if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
122e8bcc37fSRamkumar Ramachandra std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
123cb395f66SLei Zhang if (minVersion) {
124cb395f66SLei Zhang deducedVersion = std::max(deducedVersion, *minVersion);
1259414db10SLei Zhang if (deducedVersion > allowedVersion) {
126cb395f66SLei Zhang return op->emitError("'")
127cb395f66SLei Zhang << op->getName() << "' requires min version "
1289414db10SLei Zhang << spirv::stringifyVersion(deducedVersion)
1299414db10SLei Zhang << " but target environment allows up to "
1309414db10SLei Zhang << spirv::stringifyVersion(allowedVersion);
1319414db10SLei Zhang }
1329414db10SLei Zhang }
133cb395f66SLei Zhang }
1349414db10SLei Zhang
135e5c85a5aSLei Zhang // Op extension requirements
136e5c85a5aSLei Zhang if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
13758df5e6dSLei Zhang if (failed(checkAndUpdateExtensionRequirements(
13858df5e6dSLei Zhang op, targetEnv, extensions.getExtensions(), deducedExtensions)))
139e5c85a5aSLei Zhang return WalkResult::interrupt();
1409414db10SLei Zhang
141e5c85a5aSLei Zhang // Op capability requirements
142e5c85a5aSLei Zhang if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
143e5c85a5aSLei Zhang if (failed(checkAndUpdateCapabilityRequirements(
14458df5e6dSLei Zhang op, targetEnv, capabilities.getCapabilities(),
145e5c85a5aSLei Zhang deducedCapabilities)))
146e5c85a5aSLei Zhang return WalkResult::interrupt();
1479414db10SLei Zhang
148e5c85a5aSLei Zhang SmallVector<Type, 4> valueTypes;
149e5c85a5aSLei Zhang valueTypes.append(op->operand_type_begin(), op->operand_type_end());
150e5c85a5aSLei Zhang valueTypes.append(op->result_type_begin(), op->result_type_end());
1519414db10SLei Zhang
152e5c85a5aSLei Zhang // Special treatment for global variables, whose type requirements are
153e5c85a5aSLei Zhang // conveyed by type attributes.
154e5c85a5aSLei Zhang if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
15590a1632dSJakub Kuderski valueTypes.push_back(globalVar.getType());
1569414db10SLei Zhang
157e5c85a5aSLei Zhang // Requirements from values' types
158e5c85a5aSLei Zhang SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
159e5c85a5aSLei Zhang SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
160e5c85a5aSLei Zhang for (Type valueType : valueTypes) {
161e5c85a5aSLei Zhang typeExtensions.clear();
162*5550c821STres Popp cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
163e5c85a5aSLei Zhang if (failed(checkAndUpdateExtensionRequirements(
16458df5e6dSLei Zhang op, targetEnv, typeExtensions, deducedExtensions)))
165e5c85a5aSLei Zhang return WalkResult::interrupt();
1669414db10SLei Zhang
167e5c85a5aSLei Zhang typeCapabilities.clear();
168*5550c821STres Popp cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
169e5c85a5aSLei Zhang if (failed(checkAndUpdateCapabilityRequirements(
17058df5e6dSLei Zhang op, targetEnv, typeCapabilities, deducedCapabilities)))
171e5c85a5aSLei Zhang return WalkResult::interrupt();
1729414db10SLei Zhang }
1739414db10SLei Zhang
1749414db10SLei Zhang return WalkResult::advance();
1759414db10SLei Zhang });
1769414db10SLei Zhang
1779414db10SLei Zhang if (walkResult.wasInterrupted())
1789414db10SLei Zhang return signalPassFailure();
1799414db10SLei Zhang
1809db53a18SRiver Riddle // TODO: verify that the deduced version is consistent with
1819414db10SLei Zhang // SPIR-V ops' maximal version requirements.
1829414db10SLei Zhang
1839414db10SLei Zhang auto triple = spirv::VerCapExtAttr::get(
1849414db10SLei Zhang deducedVersion, deducedCapabilities.getArrayRef(),
1859414db10SLei Zhang deducedExtensions.getArrayRef(), &getContext());
1861ffc1aaaSChristian Sigg module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
1879414db10SLei Zhang }
188