1 //===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to lower attributes that specify the shader ABI 10 // for the functions in the generated SPIR-V module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/SPIRV/Transforms/Passes.h" 15 16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 19 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 21 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 22 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 23 #include "mlir/IR/BuiltinAttributes.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/ADT/SetVector.h" 26 27 namespace mlir { 28 namespace spirv { 29 #define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS 30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc" 31 } // namespace spirv 32 } // namespace mlir 33 34 using namespace mlir; 35 36 /// Creates a global variable for an argument based on the ABI info. 37 static spirv::GlobalVariableOp 38 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, 39 unsigned argIndex, 40 spirv::InterfaceVarABIAttr abiInfo) { 41 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>(); 42 if (!spirvModule) 43 return nullptr; 44 45 OpBuilder::InsertionGuard moduleInsertionGuard(builder); 46 builder.setInsertionPoint(funcOp.getOperation()); 47 std::string varName = 48 funcOp.getName().str() + "_arg_" + std::to_string(argIndex); 49 50 // Get the type of variable. If this is a scalar/vector type and has an ABI 51 // info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If 52 // not it must already be a !spirv.ptr<!spirv.struct<...>>. 53 auto varType = funcOp.getFunctionType().getInput(argIndex); 54 if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) { 55 auto storageClass = abiInfo.getStorageClass(); 56 if (!storageClass) 57 return nullptr; 58 varType = 59 spirv::PointerType::get(spirv::StructType::get(varType), *storageClass); 60 } 61 auto varPtrType = cast<spirv::PointerType>(varType); 62 auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType()); 63 64 // Set the offset information. 65 varPointeeType = 66 cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType)); 67 68 if (!varPointeeType) 69 return nullptr; 70 71 varType = 72 spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); 73 74 return builder.create<spirv::GlobalVariableOp>( 75 funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(), 76 abiInfo.getBinding()); 77 } 78 79 /// Gets the global variables that need to be specified as interface variable 80 /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so. 81 static LogicalResult 82 getInterfaceVariables(spirv::FuncOp funcOp, 83 SmallVectorImpl<Attribute> &interfaceVars) { 84 auto module = funcOp->getParentOfType<spirv::ModuleOp>(); 85 if (!module) { 86 return failure(); 87 } 88 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp); 89 spirv::TargetEnv targetEnv(targetEnvAttr); 90 91 SetVector<Operation *> interfaceVarSet; 92 93 // TODO: This should in reality traverse the entry function 94 // call graph and collect all the interfaces. For now, just traverse the 95 // instructions in this function. 96 funcOp.walk([&](spirv::AddressOfOp addressOfOp) { 97 auto var = 98 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable()); 99 // Per SPIR-V spec: "Before version 1.4, the interface's 100 // storage classes are limited to the Input and Output storage classes. 101 // Starting with version 1.4, the interface's storage classes are all 102 // storage classes used in declaring all global variables referenced by the 103 // entry point’s call tree." 104 const spirv::StorageClass storageClass = 105 cast<spirv::PointerType>(var.getType()).getStorageClass(); 106 if ((targetEnvAttr && targetEnv.getVersion() >= spirv::Version::V_1_4) || 107 (llvm::is_contained( 108 {spirv::StorageClass::Input, spirv::StorageClass::Output}, 109 storageClass))) { 110 interfaceVarSet.insert(var.getOperation()); 111 } 112 }); 113 for (auto &var : interfaceVarSet) { 114 interfaceVars.push_back(SymbolRefAttr::get( 115 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName())); 116 } 117 return success(); 118 } 119 120 /// Lowers the entry point attribute. 121 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, 122 OpBuilder &builder) { 123 auto entryPointAttrName = spirv::getEntryPointABIAttrName(); 124 auto entryPointAttr = 125 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName); 126 if (!entryPointAttr) { 127 return failure(); 128 } 129 130 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp); 131 spirv::TargetEnv targetEnv(targetEnvAttr); 132 133 OpBuilder::InsertionGuard moduleInsertionGuard(builder); 134 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>(); 135 builder.setInsertionPointToEnd(spirvModule.getBody()); 136 137 // Adds the spirv.EntryPointOp after collecting all the interface variables 138 // needed. 139 SmallVector<Attribute, 1> interfaceVars; 140 if (failed(getInterfaceVariables(funcOp, interfaceVars))) { 141 return failure(); 142 } 143 144 FailureOr<spirv::ExecutionModel> executionModel = 145 spirv::getExecutionModel(targetEnvAttr); 146 if (failed(executionModel)) 147 return funcOp.emitRemark("lower entry point failure: could not select " 148 "execution model based on 'spirv.target_env'"); 149 150 builder.create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp, 151 interfaceVars); 152 153 // Specifies the spirv.ExecutionModeOp. 154 if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) { 155 std::optional<ArrayRef<spirv::Capability>> caps = 156 spirv::getCapabilities(spirv::ExecutionMode::LocalSize); 157 if (!caps || targetEnv.allows(*caps)) { 158 builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp, 159 spirv::ExecutionMode::LocalSize, 160 workgroupSizeAttr.asArrayRef()); 161 // Erase workgroup size. 162 entryPointAttr = spirv::EntryPointABIAttr::get( 163 entryPointAttr.getContext(), DenseI32ArrayAttr(), 164 entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth()); 165 } 166 } 167 if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) { 168 std::optional<ArrayRef<spirv::Capability>> caps = 169 spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize); 170 if (!caps || targetEnv.allows(*caps)) { 171 builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp, 172 spirv::ExecutionMode::SubgroupSize, 173 *subgroupSize); 174 // Erase subgroup size. 175 entryPointAttr = spirv::EntryPointABIAttr::get( 176 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(), 177 std::nullopt, entryPointAttr.getTargetWidth()); 178 } 179 } 180 if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) { 181 std::optional<ArrayRef<spirv::Capability>> caps = 182 spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve); 183 if (!caps || targetEnv.allows(*caps)) { 184 builder.create<spirv::ExecutionModeOp>( 185 funcOp.getLoc(), funcOp, 186 spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth); 187 // Erase target width. 188 entryPointAttr = spirv::EntryPointABIAttr::get( 189 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(), 190 entryPointAttr.getSubgroupSize(), std::nullopt); 191 } 192 } 193 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() || 194 entryPointAttr.getTargetWidth()) 195 funcOp->setAttr(entryPointAttrName, entryPointAttr); 196 else 197 funcOp->removeAttr(entryPointAttrName); 198 return success(); 199 } 200 201 namespace { 202 /// A pattern to convert function signature according to interface variable ABI 203 /// attributes. 204 /// 205 /// Specifically, this pattern creates global variables according to interface 206 /// variable ABI attributes attached to function arguments and converts all 207 /// function argument uses to those global variables. This is necessary because 208 /// Vulkan requires all shader entry points to be of void(void) type. 209 class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> { 210 public: 211 using OpConversionPattern<spirv::FuncOp>::OpConversionPattern; 212 213 LogicalResult 214 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, 215 ConversionPatternRewriter &rewriter) const override; 216 }; 217 218 /// Pass to implement the ABI information specified as attributes. 219 class LowerABIAttributesPass final 220 : public spirv::impl::SPIRVLowerABIAttributesPassBase< 221 LowerABIAttributesPass> { 222 void runOnOperation() override; 223 }; 224 } // namespace 225 226 LogicalResult ProcessInterfaceVarABI::matchAndRewrite( 227 spirv::FuncOp funcOp, OpAdaptor adaptor, 228 ConversionPatternRewriter &rewriter) const { 229 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>( 230 spirv::getEntryPointABIAttrName())) { 231 // TODO: Non-entry point functions are not handled. 232 return failure(); 233 } 234 TypeConverter::SignatureConversion signatureConverter( 235 funcOp.getFunctionType().getNumInputs()); 236 237 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 238 auto indexType = typeConverter.getIndexType(); 239 240 auto attrName = spirv::getInterfaceVarABIAttrName(); 241 242 OpBuilder::InsertionGuard funcInsertionGuard(rewriter); 243 rewriter.setInsertionPointToStart(&funcOp.front()); 244 245 for (const auto &argType : 246 llvm::enumerate(funcOp.getFunctionType().getInputs())) { 247 auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>( 248 argType.index(), attrName); 249 if (!abiInfo) { 250 // TODO: For non-entry point functions, it should be legal 251 // to pass around scalar/vector values and return a scalar/vector. For now 252 // non-entry point functions are not handled in this ABI lowering and will 253 // produce an error. 254 return failure(); 255 } 256 spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument( 257 rewriter, funcOp, argType.index(), abiInfo); 258 if (!var) 259 return failure(); 260 261 // Insert spirv::AddressOf and spirv::AccessChain operations. 262 Value replacement = 263 rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var); 264 // Check if the arg is a scalar or vector type. In that case, the value 265 // needs to be loaded into registers. 266 // TODO: This is loading value of the scalar into registers 267 // at the start of the function. It is probably better to do the load just 268 // before the use. There might be multiple loads and currently there is no 269 // easy way to replace all uses with a sequence of operations. 270 if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) { 271 auto zero = 272 spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); 273 auto loadPtr = rewriter.create<spirv::AccessChainOp>( 274 funcOp.getLoc(), replacement, zero.getConstant()); 275 replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr); 276 } 277 signatureConverter.remapInput(argType.index(), replacement); 278 } 279 if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(), 280 &signatureConverter))) 281 return failure(); 282 283 // Creates a new function with the update signature. 284 rewriter.modifyOpInPlace(funcOp, [&] { 285 funcOp.setType(rewriter.getFunctionType( 286 signatureConverter.getConvertedTypes(), std::nullopt)); 287 }); 288 return success(); 289 } 290 291 void LowerABIAttributesPass::runOnOperation() { 292 // Uses the signature conversion methodology of the dialect conversion 293 // framework to implement the conversion. 294 spirv::ModuleOp module = getOperation(); 295 MLIRContext *context = &getContext(); 296 297 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(module); 298 if (!targetEnvAttr) { 299 module->emitOpError("missing SPIR-V target env attribute"); 300 return signalPassFailure(); 301 } 302 spirv::TargetEnv targetEnv(targetEnvAttr); 303 304 SPIRVTypeConverter typeConverter(targetEnv); 305 306 // Insert a bitcast in the case of a pointer type change. 307 typeConverter.addSourceMaterialization([](OpBuilder &builder, 308 spirv::PointerType type, 309 ValueRange inputs, Location loc) { 310 if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType())) 311 return Value(); 312 return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult(); 313 }); 314 315 RewritePatternSet patterns(context); 316 patterns.add<ProcessInterfaceVarABI>(typeConverter, context); 317 318 ConversionTarget target(*context); 319 // "Legal" function ops should have no interface variable ABI attributes. 320 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) { 321 StringRef attrName = spirv::getInterfaceVarABIAttrName(); 322 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) 323 if (op.getArgAttr(i, attrName)) 324 return false; 325 return true; 326 }); 327 // All other SPIR-V ops are legal. 328 target.markUnknownOpDynamicallyLegal([](Operation *op) { 329 return op->getDialect()->getNamespace() == 330 spirv::SPIRVDialect::getDialectNamespace(); 331 }); 332 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 333 return signalPassFailure(); 334 335 // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point 336 // attributes. 337 OpBuilder builder(context); 338 SmallVector<spirv::FuncOp, 1> entryPointFns; 339 auto entryPointAttrName = spirv::getEntryPointABIAttrName(); 340 module.walk([&](spirv::FuncOp funcOp) { 341 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) { 342 entryPointFns.push_back(funcOp); 343 } 344 }); 345 for (auto fn : entryPointFns) { 346 if (failed(lowerEntryPointABIAttr(fn, builder))) { 347 return signalPassFailure(); 348 } 349 } 350 } 351