xref: /llvm-project/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (revision b95dfa3920f71c42ef2991f42a95903cc1202c55)
//===- LowerABIAttributesPass.cpp - Lower ABI attributes ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to lower attributes that specify the shader ABI
10 // for the functions in the generated SPIR-V module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
22 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/SetVector.h"
26 
27 namespace mlir {
28 namespace spirv {
29 #define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
31 } // namespace spirv
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 /// Creates a global variable for an argument based on the ABI info.
37 static spirv::GlobalVariableOp
38 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
39                                      unsigned argIndex,
40                                      spirv::InterfaceVarABIAttr abiInfo) {
41   auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
42   if (!spirvModule)
43     return nullptr;
44 
45   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
46   builder.setInsertionPoint(funcOp.getOperation());
47   std::string varName =
48       funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
49 
50   // Get the type of variable. If this is a scalar/vector type and has an ABI
51   // info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
52   // not it must already be a !spirv.ptr<!spirv.struct<...>>.
53   auto varType = funcOp.getFunctionType().getInput(argIndex);
54   if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
55     auto storageClass = abiInfo.getStorageClass();
56     if (!storageClass)
57       return nullptr;
58     varType =
59         spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
60   }
61   auto varPtrType = cast<spirv::PointerType>(varType);
62   auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
63 
64   // Set the offset information.
65   varPointeeType =
66       cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));
67 
68   if (!varPointeeType)
69     return nullptr;
70 
71   varType =
72       spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
73 
74   return builder.create<spirv::GlobalVariableOp>(
75       funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
76       abiInfo.getBinding());
77 }
78 
79 /// Gets the global variables that need to be specified as interface variable
80 /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
81 static LogicalResult
82 getInterfaceVariables(spirv::FuncOp funcOp,
83                       SmallVectorImpl<Attribute> &interfaceVars) {
84   auto module = funcOp->getParentOfType<spirv::ModuleOp>();
85   if (!module) {
86     return failure();
87   }
88   spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
89   spirv::TargetEnv targetEnv(targetEnvAttr);
90 
91   SetVector<Operation *> interfaceVarSet;
92 
93   // TODO: This should in reality traverse the entry function
94   // call graph and collect all the interfaces. For now, just traverse the
95   // instructions in this function.
96   funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
97     auto var =
98         module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
99     // Per SPIR-V spec: "Before version 1.4, the interface's
100     // storage classes are limited to the Input and Output storage classes.
101     // Starting with version 1.4, the interface's storage classes are all
102     // storage classes used in declaring all global variables referenced by the
103     // entry point’s call tree."
104     const spirv::StorageClass storageClass =
105         cast<spirv::PointerType>(var.getType()).getStorageClass();
106     if ((targetEnvAttr && targetEnv.getVersion() >= spirv::Version::V_1_4) ||
107         (llvm::is_contained(
108             {spirv::StorageClass::Input, spirv::StorageClass::Output},
109             storageClass))) {
110       interfaceVarSet.insert(var.getOperation());
111     }
112   });
113   for (auto &var : interfaceVarSet) {
114     interfaceVars.push_back(SymbolRefAttr::get(
115         funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
116   }
117   return success();
118 }
119 
120 /// Lowers the entry point attribute.
121 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
122                                             OpBuilder &builder) {
123   auto entryPointAttrName = spirv::getEntryPointABIAttrName();
124   auto entryPointAttr =
125       funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
126   if (!entryPointAttr) {
127     return failure();
128   }
129 
130   spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
131   spirv::TargetEnv targetEnv(targetEnvAttr);
132 
133   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
134   auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
135   builder.setInsertionPointToEnd(spirvModule.getBody());
136 
137   // Adds the spirv.EntryPointOp after collecting all the interface variables
138   // needed.
139   SmallVector<Attribute, 1> interfaceVars;
140   if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
141     return failure();
142   }
143 
144   FailureOr<spirv::ExecutionModel> executionModel =
145       spirv::getExecutionModel(targetEnvAttr);
146   if (failed(executionModel))
147     return funcOp.emitRemark("lower entry point failure: could not select "
148                              "execution model based on 'spirv.target_env'");
149 
150   builder.create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
151                                       interfaceVars);
152 
153   // Specifies the spirv.ExecutionModeOp.
154   if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
155     std::optional<ArrayRef<spirv::Capability>> caps =
156         spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
157     if (!caps || targetEnv.allows(*caps)) {
158       builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
159                                              spirv::ExecutionMode::LocalSize,
160                                              workgroupSizeAttr.asArrayRef());
161       // Erase workgroup size.
162       entryPointAttr = spirv::EntryPointABIAttr::get(
163           entryPointAttr.getContext(), DenseI32ArrayAttr(),
164           entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
165     }
166   }
167   if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
168     std::optional<ArrayRef<spirv::Capability>> caps =
169         spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
170     if (!caps || targetEnv.allows(*caps)) {
171       builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
172                                              spirv::ExecutionMode::SubgroupSize,
173                                              *subgroupSize);
174       // Erase subgroup size.
175       entryPointAttr = spirv::EntryPointABIAttr::get(
176           entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
177           std::nullopt, entryPointAttr.getTargetWidth());
178     }
179   }
180   if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
181     std::optional<ArrayRef<spirv::Capability>> caps =
182         spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
183     if (!caps || targetEnv.allows(*caps)) {
184       builder.create<spirv::ExecutionModeOp>(
185           funcOp.getLoc(), funcOp,
186           spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
187       // Erase target width.
188       entryPointAttr = spirv::EntryPointABIAttr::get(
189           entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
190           entryPointAttr.getSubgroupSize(), std::nullopt);
191     }
192   }
193   if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
194       entryPointAttr.getTargetWidth())
195     funcOp->setAttr(entryPointAttrName, entryPointAttr);
196   else
197     funcOp->removeAttr(entryPointAttrName);
198   return success();
199 }
200 
201 namespace {
202 /// A pattern to convert function signature according to interface variable ABI
203 /// attributes.
204 ///
205 /// Specifically, this pattern creates global variables according to interface
206 /// variable ABI attributes attached to function arguments and converts all
207 /// function argument uses to those global variables. This is necessary because
208 /// Vulkan requires all shader entry points to be of void(void) type.
209 class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
210 public:
211   using OpConversionPattern<spirv::FuncOp>::OpConversionPattern;
212 
213   LogicalResult
214   matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
215                   ConversionPatternRewriter &rewriter) const override;
216 };
217 
218 /// Pass to implement the ABI information specified as attributes.
219 class LowerABIAttributesPass final
220     : public spirv::impl::SPIRVLowerABIAttributesPassBase<
221           LowerABIAttributesPass> {
222   void runOnOperation() override;
223 };
224 } // namespace
225 
226 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
227     spirv::FuncOp funcOp, OpAdaptor adaptor,
228     ConversionPatternRewriter &rewriter) const {
229   if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
230           spirv::getEntryPointABIAttrName())) {
231     // TODO: Non-entry point functions are not handled.
232     return failure();
233   }
234   TypeConverter::SignatureConversion signatureConverter(
235       funcOp.getFunctionType().getNumInputs());
236 
237   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
238   auto indexType = typeConverter.getIndexType();
239 
240   auto attrName = spirv::getInterfaceVarABIAttrName();
241 
242   OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
243   rewriter.setInsertionPointToStart(&funcOp.front());
244 
245   for (const auto &argType :
246        llvm::enumerate(funcOp.getFunctionType().getInputs())) {
247     auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
248         argType.index(), attrName);
249     if (!abiInfo) {
250       // TODO: For non-entry point functions, it should be legal
251       // to pass around scalar/vector values and return a scalar/vector. For now
252       // non-entry point functions are not handled in this ABI lowering and will
253       // produce an error.
254       return failure();
255     }
256     spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
257         rewriter, funcOp, argType.index(), abiInfo);
258     if (!var)
259       return failure();
260 
261     // Insert spirv::AddressOf and spirv::AccessChain operations.
262     Value replacement =
263         rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
264     // Check if the arg is a scalar or vector type. In that case, the value
265     // needs to be loaded into registers.
266     // TODO: This is loading value of the scalar into registers
267     // at the start of the function. It is probably better to do the load just
268     // before the use. There might be multiple loads and currently there is no
269     // easy way to replace all uses with a sequence of operations.
270     if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
271       auto zero =
272           spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
273       auto loadPtr = rewriter.create<spirv::AccessChainOp>(
274           funcOp.getLoc(), replacement, zero.getConstant());
275       replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
276     }
277     signatureConverter.remapInput(argType.index(), replacement);
278   }
279   if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
280                                          &signatureConverter)))
281     return failure();
282 
283   // Creates a new function with the update signature.
284   rewriter.modifyOpInPlace(funcOp, [&] {
285     funcOp.setType(rewriter.getFunctionType(
286         signatureConverter.getConvertedTypes(), std::nullopt));
287   });
288   return success();
289 }
290 
291 void LowerABIAttributesPass::runOnOperation() {
292   // Uses the signature conversion methodology of the dialect conversion
293   // framework to implement the conversion.
294   spirv::ModuleOp module = getOperation();
295   MLIRContext *context = &getContext();
296 
297   spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(module);
298   if (!targetEnvAttr) {
299     module->emitOpError("missing SPIR-V target env attribute");
300     return signalPassFailure();
301   }
302   spirv::TargetEnv targetEnv(targetEnvAttr);
303 
304   SPIRVTypeConverter typeConverter(targetEnv);
305 
306   // Insert a bitcast in the case of a pointer type change.
307   typeConverter.addSourceMaterialization([](OpBuilder &builder,
308                                             spirv::PointerType type,
309                                             ValueRange inputs, Location loc) {
310     if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
311       return Value();
312     return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
313   });
314 
315   RewritePatternSet patterns(context);
316   patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
317 
318   ConversionTarget target(*context);
319   // "Legal" function ops should have no interface variable ABI attributes.
320   target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
321     StringRef attrName = spirv::getInterfaceVarABIAttrName();
322     for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
323       if (op.getArgAttr(i, attrName))
324         return false;
325     return true;
326   });
327   // All other SPIR-V ops are legal.
328   target.markUnknownOpDynamicallyLegal([](Operation *op) {
329     return op->getDialect()->getNamespace() ==
330            spirv::SPIRVDialect::getDialectNamespace();
331   });
332   if (failed(applyPartialConversion(module, target, std::move(patterns))))
333     return signalPassFailure();
334 
335   // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
336   // attributes.
337   OpBuilder builder(context);
338   SmallVector<spirv::FuncOp, 1> entryPointFns;
339   auto entryPointAttrName = spirv::getEntryPointABIAttrName();
340   module.walk([&](spirv::FuncOp funcOp) {
341     if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
342       entryPointFns.push_back(funcOp);
343     }
344   });
345   for (auto fn : entryPointFns) {
346     if (failed(lowerEntryPointABIAttr(fn, builder))) {
347       return signalPassFailure();
348     }
349   }
350 }
351