//===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
801178654SLei Zhang
901178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
10e672f512SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1101178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1201178654SLei Zhang #include "mlir/IR/Builders.h"
1301178654SLei Zhang #include "mlir/IR/Operation.h"
1401178654SLei Zhang #include "mlir/IR/SymbolTable.h"
1534a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h"
16a1fe1f5fSKazu Hirata #include <optional>
1701178654SLei Zhang
1801178654SLei Zhang using namespace mlir;
1901178654SLei Zhang
//===----------------------------------------------------------------------===//
// TargetEnv
//===----------------------------------------------------------------------===//
2301178654SLei Zhang
TargetEnv(spirv::TargetEnvAttr targetAttr)2401178654SLei Zhang spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
2501178654SLei Zhang : targetAttr(targetAttr) {
2601178654SLei Zhang for (spirv::Extension ext : targetAttr.getExtensions())
2701178654SLei Zhang givenExtensions.insert(ext);
2801178654SLei Zhang
2901178654SLei Zhang // Add extensions implied by the current version.
3001178654SLei Zhang for (spirv::Extension ext :
3101178654SLei Zhang spirv::getImpliedExtensions(targetAttr.getVersion()))
3201178654SLei Zhang givenExtensions.insert(ext);
3301178654SLei Zhang
3401178654SLei Zhang for (spirv::Capability cap : targetAttr.getCapabilities()) {
3501178654SLei Zhang givenCapabilities.insert(cap);
3601178654SLei Zhang
3701178654SLei Zhang // Add capabilities implied by the current capability.
3801178654SLei Zhang for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
3901178654SLei Zhang givenCapabilities.insert(c);
4001178654SLei Zhang }
4101178654SLei Zhang }
4201178654SLei Zhang
getVersion() const4301178654SLei Zhang spirv::Version spirv::TargetEnv::getVersion() const {
4401178654SLei Zhang return targetAttr.getVersion();
4501178654SLei Zhang }
4601178654SLei Zhang
allows(spirv::Capability capability) const4701178654SLei Zhang bool spirv::TargetEnv::allows(spirv::Capability capability) const {
4801178654SLei Zhang return givenCapabilities.count(capability);
4901178654SLei Zhang }
5001178654SLei Zhang
510a81ace0SKazu Hirata std::optional<spirv::Capability>
allows(ArrayRef<spirv::Capability> caps) const5201178654SLei Zhang spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
5301178654SLei Zhang const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
5401178654SLei Zhang return givenCapabilities.count(cap);
5501178654SLei Zhang });
5601178654SLei Zhang if (chosen != caps.end())
5701178654SLei Zhang return *chosen;
581a36588eSKazu Hirata return std::nullopt;
5901178654SLei Zhang }
6001178654SLei Zhang
allows(spirv::Extension extension) const6101178654SLei Zhang bool spirv::TargetEnv::allows(spirv::Extension extension) const {
6201178654SLei Zhang return givenExtensions.count(extension);
6301178654SLei Zhang }
6401178654SLei Zhang
650a81ace0SKazu Hirata std::optional<spirv::Extension>
allows(ArrayRef<spirv::Extension> exts) const6601178654SLei Zhang spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
6701178654SLei Zhang const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
6801178654SLei Zhang return givenExtensions.count(ext);
6901178654SLei Zhang });
7001178654SLei Zhang if (chosen != exts.end())
7101178654SLei Zhang return *chosen;
721a36588eSKazu Hirata return std::nullopt;
7301178654SLei Zhang }
7401178654SLei Zhang
getVendorID() const7501178654SLei Zhang spirv::Vendor spirv::TargetEnv::getVendorID() const {
7601178654SLei Zhang return targetAttr.getVendorID();
7701178654SLei Zhang }
7801178654SLei Zhang
getDeviceType() const7901178654SLei Zhang spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
8001178654SLei Zhang return targetAttr.getDeviceType();
8101178654SLei Zhang }
8201178654SLei Zhang
getDeviceID() const8301178654SLei Zhang uint32_t spirv::TargetEnv::getDeviceID() const {
8401178654SLei Zhang return targetAttr.getDeviceID();
8501178654SLei Zhang }
8601178654SLei Zhang
getResourceLimits() const8701178654SLei Zhang spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
8801178654SLei Zhang return targetAttr.getResourceLimits();
8901178654SLei Zhang }
9001178654SLei Zhang
getContext() const9101178654SLei Zhang MLIRContext *spirv::TargetEnv::getContext() const {
9201178654SLei Zhang return targetAttr.getContext();
9301178654SLei Zhang }
9401178654SLei Zhang
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
9801178654SLei Zhang
getInterfaceVarABIAttrName()9901178654SLei Zhang StringRef spirv::getInterfaceVarABIAttrName() {
1005ab6ef75SJakub Kuderski return "spirv.interface_var_abi";
10101178654SLei Zhang }
10201178654SLei Zhang
10301178654SLei Zhang spirv::InterfaceVarABIAttr
getInterfaceVarABIAttr(unsigned descriptorSet,unsigned binding,std::optional<spirv::StorageClass> storageClass,MLIRContext * context)10401178654SLei Zhang spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
105e8bcc37fSRamkumar Ramachandra std::optional<spirv::StorageClass> storageClass,
10601178654SLei Zhang MLIRContext *context) {
10701178654SLei Zhang return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
10801178654SLei Zhang context);
10901178654SLei Zhang }
11001178654SLei Zhang
needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr)11101178654SLei Zhang bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) {
11201178654SLei Zhang for (spirv::Capability cap : targetAttr.getCapabilities()) {
11301178654SLei Zhang if (cap == spirv::Capability::Kernel)
11401178654SLei Zhang return false;
11501178654SLei Zhang if (cap == spirv::Capability::Shader)
11601178654SLei Zhang return true;
11701178654SLei Zhang }
11801178654SLei Zhang return false;
11901178654SLei Zhang }
12001178654SLei Zhang
getEntryPointABIAttrName()1215ab6ef75SJakub Kuderski StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
12201178654SLei Zhang
getEntryPointABIAttr(MLIRContext * context,ArrayRef<int32_t> workgroupSize,std::optional<int> subgroupSize,std::optional<int> targetWidth)123*9dbb6e19SHsiangkai Wang spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(
124*9dbb6e19SHsiangkai Wang MLIRContext *context, ArrayRef<int32_t> workgroupSize,
125*9dbb6e19SHsiangkai Wang std::optional<int> subgroupSize, std::optional<int> targetWidth) {
12652ca1499SLei Zhang DenseI32ArrayAttr workgroupSizeAttr;
12752ca1499SLei Zhang if (!workgroupSize.empty()) {
12852ca1499SLei Zhang assert(workgroupSize.size() == 3);
12952ca1499SLei Zhang workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
13052ca1499SLei Zhang }
131*9dbb6e19SHsiangkai Wang return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize,
132*9dbb6e19SHsiangkai Wang targetWidth);
13301178654SLei Zhang }
13401178654SLei Zhang
lookupEntryPointABI(Operation * op)13501178654SLei Zhang spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
1367ceffae1SRiver Riddle while (op && !isa<FunctionOpInterface>(op))
13701178654SLei Zhang op = op->getParentOp();
13801178654SLei Zhang if (!op)
13901178654SLei Zhang return {};
14001178654SLei Zhang
14101178654SLei Zhang if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
14201178654SLei Zhang spirv::getEntryPointABIAttrName()))
14301178654SLei Zhang return attr;
14401178654SLei Zhang
14501178654SLei Zhang return {};
14601178654SLei Zhang }
14701178654SLei Zhang
lookupLocalWorkGroupSize(Operation * op)14852ca1499SLei Zhang DenseI32ArrayAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
14901178654SLei Zhang if (auto entryPoint = spirv::lookupEntryPointABI(op))
15052ca1499SLei Zhang return entryPoint.getWorkgroupSize();
15101178654SLei Zhang
15201178654SLei Zhang return {};
15301178654SLei Zhang }
15401178654SLei Zhang
15501178654SLei Zhang spirv::ResourceLimitsAttr
getDefaultResourceLimits(MLIRContext * context)15601178654SLei Zhang spirv::getDefaultResourceLimits(MLIRContext *context) {
15701178654SLei Zhang // All the fields have default values. Here we just provide a nicer way to
15801178654SLei Zhang // construct a default resource limit attribute.
159a31ff0afSMogball Builder b(context);
16001178654SLei Zhang return spirv::ResourceLimitsAttr::get(
161a31ff0afSMogball context,
162a31ff0afSMogball /*max_compute_shared_memory_size=*/16384,
163a31ff0afSMogball /*max_compute_workgroup_invocations=*/128,
164a31ff0afSMogball /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}),
165a31ff0afSMogball /*subgroup_size=*/32,
1661a36588eSKazu Hirata /*min_subgroup_size=*/std::nullopt,
1671a36588eSKazu Hirata /*max_subgroup_size=*/std::nullopt,
168d13da154SJakub Kuderski /*cooperative_matrix_properties_khr=*/ArrayAttr{},
169d13da154SJakub Kuderski /*cooperative_matrix_properties_nv=*/ArrayAttr{});
17001178654SLei Zhang }
17101178654SLei Zhang
getTargetEnvAttrName()1725ab6ef75SJakub Kuderski StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
17301178654SLei Zhang
getDefaultTargetEnv(MLIRContext * context)17401178654SLei Zhang spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
17501178654SLei Zhang auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
17601178654SLei Zhang {spirv::Capability::Shader},
17701178654SLei Zhang ArrayRef<Extension>(), context);
178e672f512SLei Zhang return spirv::TargetEnvAttr::get(
179e672f512SLei Zhang triple, spirv::getDefaultResourceLimits(context),
180e672f512SLei Zhang spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
181e672f512SLei Zhang spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
18201178654SLei Zhang }
18301178654SLei Zhang
lookupTargetEnv(Operation * op)18401178654SLei Zhang spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
18501178654SLei Zhang while (op) {
18601178654SLei Zhang op = SymbolTable::getNearestSymbolTable(op);
18701178654SLei Zhang if (!op)
18801178654SLei Zhang break;
18901178654SLei Zhang
19001178654SLei Zhang if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
19101178654SLei Zhang spirv::getTargetEnvAttrName()))
19201178654SLei Zhang return attr;
19301178654SLei Zhang
19401178654SLei Zhang op = op->getParentOp();
19501178654SLei Zhang }
19601178654SLei Zhang
19701178654SLei Zhang return {};
19801178654SLei Zhang }
19901178654SLei Zhang
lookupTargetEnvOrDefault(Operation * op)20001178654SLei Zhang spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
20101178654SLei Zhang if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
20201178654SLei Zhang return attr;
20301178654SLei Zhang
20401178654SLei Zhang return getDefaultTargetEnv(op->getContext());
20501178654SLei Zhang }
20601178654SLei Zhang
20701178654SLei Zhang spirv::AddressingModel
getAddressingModel(spirv::TargetEnvAttr targetAttr,bool use64bitAddress)20885365b16SLei Zhang spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr,
20985365b16SLei Zhang bool use64bitAddress) {
21001178654SLei Zhang for (spirv::Capability cap : targetAttr.getCapabilities()) {
21101178654SLei Zhang if (cap == Capability::Kernel)
21285365b16SLei Zhang return use64bitAddress ? spirv::AddressingModel::Physical64
21385365b16SLei Zhang : spirv::AddressingModel::Physical32;
2148324561eSAlexander Batashev // TODO PhysicalStorageBuffer64 is hard-coded here, but some information
2158324561eSAlexander Batashev // should come from TargetEnvAttr to select between PhysicalStorageBuffer64
2168324561eSAlexander Batashev // and PhysicalStorageBuffer64EXT
2178324561eSAlexander Batashev if (cap == Capability::PhysicalStorageBufferAddresses)
2188324561eSAlexander Batashev return spirv::AddressingModel::PhysicalStorageBuffer64;
21901178654SLei Zhang }
22001178654SLei Zhang // Logical addressing doesn't need any capabilities so return it as default.
22101178654SLei Zhang return spirv::AddressingModel::Logical;
22201178654SLei Zhang }
22301178654SLei Zhang
22401178654SLei Zhang FailureOr<spirv::ExecutionModel>
getExecutionModel(spirv::TargetEnvAttr targetAttr)22501178654SLei Zhang spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) {
22601178654SLei Zhang for (spirv::Capability cap : targetAttr.getCapabilities()) {
22701178654SLei Zhang if (cap == spirv::Capability::Kernel)
22801178654SLei Zhang return spirv::ExecutionModel::Kernel;
22901178654SLei Zhang if (cap == spirv::Capability::Shader)
23001178654SLei Zhang return spirv::ExecutionModel::GLCompute;
23101178654SLei Zhang }
23201178654SLei Zhang return failure();
23301178654SLei Zhang }
23401178654SLei Zhang
23501178654SLei Zhang FailureOr<spirv::MemoryModel>
getMemoryModel(spirv::TargetEnvAttr targetAttr)23601178654SLei Zhang spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) {
23701178654SLei Zhang for (spirv::Capability cap : targetAttr.getCapabilities()) {
23885365b16SLei Zhang if (cap == spirv::Capability::Kernel)
23901178654SLei Zhang return spirv::MemoryModel::OpenCL;
24001178654SLei Zhang if (cap == spirv::Capability::Shader)
24101178654SLei Zhang return spirv::MemoryModel::GLSL450;
24201178654SLei Zhang }
24301178654SLei Zhang return failure();
24401178654SLei Zhang }
245