xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp (revision 9dbb6e1978d3d2d61ef65c2dac1fd8add5a4c7a2)
1 //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/SymbolTable.h"
15 #include "mlir/Interfaces/FunctionInterfaces.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // TargetEnv
22 //===----------------------------------------------------------------------===//
23 
TargetEnv(spirv::TargetEnvAttr targetAttr)24 spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
25     : targetAttr(targetAttr) {
26   for (spirv::Extension ext : targetAttr.getExtensions())
27     givenExtensions.insert(ext);
28 
29   // Add extensions implied by the current version.
30   for (spirv::Extension ext :
31        spirv::getImpliedExtensions(targetAttr.getVersion()))
32     givenExtensions.insert(ext);
33 
34   for (spirv::Capability cap : targetAttr.getCapabilities()) {
35     givenCapabilities.insert(cap);
36 
37     // Add capabilities implied by the current capability.
38     for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
39       givenCapabilities.insert(c);
40   }
41 }
42 
getVersion() const43 spirv::Version spirv::TargetEnv::getVersion() const {
44   return targetAttr.getVersion();
45 }
46 
allows(spirv::Capability capability) const47 bool spirv::TargetEnv::allows(spirv::Capability capability) const {
48   return givenCapabilities.count(capability);
49 }
50 
51 std::optional<spirv::Capability>
allows(ArrayRef<spirv::Capability> caps) const52 spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
53   const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
54     return givenCapabilities.count(cap);
55   });
56   if (chosen != caps.end())
57     return *chosen;
58   return std::nullopt;
59 }
60 
allows(spirv::Extension extension) const61 bool spirv::TargetEnv::allows(spirv::Extension extension) const {
62   return givenExtensions.count(extension);
63 }
64 
65 std::optional<spirv::Extension>
allows(ArrayRef<spirv::Extension> exts) const66 spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
67   const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
68     return givenExtensions.count(ext);
69   });
70   if (chosen != exts.end())
71     return *chosen;
72   return std::nullopt;
73 }
74 
getVendorID() const75 spirv::Vendor spirv::TargetEnv::getVendorID() const {
76   return targetAttr.getVendorID();
77 }
78 
getDeviceType() const79 spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
80   return targetAttr.getDeviceType();
81 }
82 
getDeviceID() const83 uint32_t spirv::TargetEnv::getDeviceID() const {
84   return targetAttr.getDeviceID();
85 }
86 
getResourceLimits() const87 spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
88   return targetAttr.getResourceLimits();
89 }
90 
getContext() const91 MLIRContext *spirv::TargetEnv::getContext() const {
92   return targetAttr.getContext();
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // Utility functions
97 //===----------------------------------------------------------------------===//
98 
getInterfaceVarABIAttrName()99 StringRef spirv::getInterfaceVarABIAttrName() {
100   return "spirv.interface_var_abi";
101 }
102 
103 spirv::InterfaceVarABIAttr
getInterfaceVarABIAttr(unsigned descriptorSet,unsigned binding,std::optional<spirv::StorageClass> storageClass,MLIRContext * context)104 spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
105                               std::optional<spirv::StorageClass> storageClass,
106                               MLIRContext *context) {
107   return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
108                                          context);
109 }
110 
needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr)111 bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) {
112   for (spirv::Capability cap : targetAttr.getCapabilities()) {
113     if (cap == spirv::Capability::Kernel)
114       return false;
115     if (cap == spirv::Capability::Shader)
116       return true;
117   }
118   return false;
119 }
120 
getEntryPointABIAttrName()121 StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
122 
getEntryPointABIAttr(MLIRContext * context,ArrayRef<int32_t> workgroupSize,std::optional<int> subgroupSize,std::optional<int> targetWidth)123 spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(
124     MLIRContext *context, ArrayRef<int32_t> workgroupSize,
125     std::optional<int> subgroupSize, std::optional<int> targetWidth) {
126   DenseI32ArrayAttr workgroupSizeAttr;
127   if (!workgroupSize.empty()) {
128     assert(workgroupSize.size() == 3);
129     workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
130   }
131   return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize,
132                                        targetWidth);
133 }
134 
lookupEntryPointABI(Operation * op)135 spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
136   while (op && !isa<FunctionOpInterface>(op))
137     op = op->getParentOp();
138   if (!op)
139     return {};
140 
141   if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
142           spirv::getEntryPointABIAttrName()))
143     return attr;
144 
145   return {};
146 }
147 
lookupLocalWorkGroupSize(Operation * op)148 DenseI32ArrayAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
149   if (auto entryPoint = spirv::lookupEntryPointABI(op))
150     return entryPoint.getWorkgroupSize();
151 
152   return {};
153 }
154 
155 spirv::ResourceLimitsAttr
getDefaultResourceLimits(MLIRContext * context)156 spirv::getDefaultResourceLimits(MLIRContext *context) {
157   // All the fields have default values. Here we just provide a nicer way to
158   // construct a default resource limit attribute.
159   Builder b(context);
160   return spirv::ResourceLimitsAttr::get(
161       context,
162       /*max_compute_shared_memory_size=*/16384,
163       /*max_compute_workgroup_invocations=*/128,
164       /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}),
165       /*subgroup_size=*/32,
166       /*min_subgroup_size=*/std::nullopt,
167       /*max_subgroup_size=*/std::nullopt,
168       /*cooperative_matrix_properties_khr=*/ArrayAttr{},
169       /*cooperative_matrix_properties_nv=*/ArrayAttr{});
170 }
171 
getTargetEnvAttrName()172 StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
173 
getDefaultTargetEnv(MLIRContext * context)174 spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
175   auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
176                                           {spirv::Capability::Shader},
177                                           ArrayRef<Extension>(), context);
178   return spirv::TargetEnvAttr::get(
179       triple, spirv::getDefaultResourceLimits(context),
180       spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
181       spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
182 }
183 
lookupTargetEnv(Operation * op)184 spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
185   while (op) {
186     op = SymbolTable::getNearestSymbolTable(op);
187     if (!op)
188       break;
189 
190     if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
191             spirv::getTargetEnvAttrName()))
192       return attr;
193 
194     op = op->getParentOp();
195   }
196 
197   return {};
198 }
199 
lookupTargetEnvOrDefault(Operation * op)200 spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
201   if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
202     return attr;
203 
204   return getDefaultTargetEnv(op->getContext());
205 }
206 
207 spirv::AddressingModel
getAddressingModel(spirv::TargetEnvAttr targetAttr,bool use64bitAddress)208 spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr,
209                           bool use64bitAddress) {
210   for (spirv::Capability cap : targetAttr.getCapabilities()) {
211     if (cap == Capability::Kernel)
212       return use64bitAddress ? spirv::AddressingModel::Physical64
213                              : spirv::AddressingModel::Physical32;
214     // TODO PhysicalStorageBuffer64 is hard-coded here, but some information
215     // should come from TargetEnvAttr to select between PhysicalStorageBuffer64
216     // and PhysicalStorageBuffer64EXT
217     if (cap == Capability::PhysicalStorageBufferAddresses)
218       return spirv::AddressingModel::PhysicalStorageBuffer64;
219   }
220   // Logical addressing doesn't need any capabilities so return it as default.
221   return spirv::AddressingModel::Logical;
222 }
223 
224 FailureOr<spirv::ExecutionModel>
getExecutionModel(spirv::TargetEnvAttr targetAttr)225 spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) {
226   for (spirv::Capability cap : targetAttr.getCapabilities()) {
227     if (cap == spirv::Capability::Kernel)
228       return spirv::ExecutionModel::Kernel;
229     if (cap == spirv::Capability::Shader)
230       return spirv::ExecutionModel::GLCompute;
231   }
232   return failure();
233 }
234 
235 FailureOr<spirv::MemoryModel>
getMemoryModel(spirv::TargetEnvAttr targetAttr)236 spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) {
237   for (spirv::Capability cap : targetAttr.getCapabilities()) {
238     if (cap == spirv::Capability::Kernel)
239       return spirv::MemoryModel::OpenCL;
240     if (cap == spirv::Capability::Shader)
241       return spirv::MemoryModel::GLSL450;
242   }
243   return failure();
244 }
245