1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Serializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 20 #include "llvm/ADT/STLExtras.h" 21 #include "llvm/ADT/Sequence.h" 22 #include "llvm/ADT/SmallPtrSet.h" 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/ADT/bit.h" 26 #include "llvm/Support/Debug.h" 27 #include <cstdint> 28 #include <optional> 29 30 #define DEBUG_TYPE "spirv-serialization" 31 32 using namespace mlir; 33 34 /// Returns the merge block if the given `op` is a structured control flow op. 35 /// Otherwise returns nullptr. 36 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { 37 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) 38 return selectionOp.getMergeBlock(); 39 if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) 40 return loopOp.getMergeBlock(); 41 return nullptr; 42 } 43 44 /// Given a predecessor `block` for a block with arguments, returns the block 45 /// that should be used as the parent block for SPIR-V OpPhi instructions 46 /// corresponding to the block arguments. 47 static Block *getPhiIncomingBlock(Block *block) { 48 // If the predecessor block in question is the entry block for a 49 // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block. 50 if (block->isEntryBlock()) { 51 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { 52 // Then the incoming parent block for OpPhi should be the merge block of 53 // the structured control flow op before this loop. 54 Operation *op = loopOp.getOperation(); 55 while ((op = op->getPrevNode()) != nullptr) 56 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) 57 return incomingBlock; 58 // Or the enclosing block itself if no structured control flow ops 59 // exists before this loop. 60 return loopOp->getBlock(); 61 } 62 } 63 64 // Otherwise, we jump from the given predecessor block. Try to see if there is 65 // a structured control flow op inside it. 66 for (Operation &op : llvm::reverse(block->getOperations())) { 67 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) 68 return incomingBlock; 69 } 70 return block; 71 } 72 73 namespace mlir { 74 namespace spirv { 75 76 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into 77 /// the given `binary` vector. 78 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op, 79 ArrayRef<uint32_t> operands) { 80 uint32_t wordCount = 1 + operands.size(); 81 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 82 binary.append(operands.begin(), operands.end()); 83 } 84 85 Serializer::Serializer(spirv::ModuleOp module, 86 const SerializationOptions &options) 87 : module(module), mlirBuilder(module.getContext()), options(options) {} 88 89 LogicalResult Serializer::serialize() { 90 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); 91 92 if (failed(module.verifyInvariants())) 93 return failure(); 94 95 // TODO: handle the other sections 96 processCapability(); 97 processExtension(); 98 processMemoryModel(); 99 processDebugInfo(); 100 101 // Iterate over the module body to serialize it. Assumptions are that there is 102 // only one basic block in the moduleOp 103 for (auto &op : *module.getBody()) { 104 if (failed(processOperation(&op))) { 105 return failure(); 106 } 107 } 108 109 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); 110 return success(); 111 } 112 113 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { 114 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + 115 extensions.size() + extendedSets.size() + 116 memoryModel.size() + entryPoints.size() + 117 executionModes.size() + decorations.size() + 118 typesGlobalValues.size() + functions.size(); 119 120 binary.clear(); 121 binary.reserve(moduleSize); 122 123 spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(), 124 nextID); 125 binary.append(capabilities.begin(), capabilities.end()); 126 binary.append(extensions.begin(), extensions.end()); 127 binary.append(extendedSets.begin(), extendedSets.end()); 128 binary.append(memoryModel.begin(), memoryModel.end()); 129 binary.append(entryPoints.begin(), entryPoints.end()); 130 binary.append(executionModes.begin(), executionModes.end()); 131 binary.append(debug.begin(), debug.end()); 132 binary.append(names.begin(), names.end()); 133 binary.append(decorations.begin(), decorations.end()); 134 binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); 135 binary.append(functions.begin(), functions.end()); 136 } 137 138 #ifndef NDEBUG 139 void Serializer::printValueIDMap(raw_ostream &os) { 140 os << "\n= Value <id> Map =\n\n"; 141 for (auto valueIDPair : valueIDMap) { 142 Value val = valueIDPair.first; 143 os << " " << val << " " 144 << "id = " << valueIDPair.second << ' '; 145 if (auto *op = val.getDefiningOp()) { 146 os << "from op '" << op->getName() << "'"; 147 } else if (auto arg = dyn_cast<BlockArgument>(val)) { 148 Block *block = arg.getOwner(); 149 os << "from argument of block " << block << ' '; 150 os << " in op '" << block->getParentOp()->getName() << "'"; 151 } 152 os << '\n'; 153 } 154 } 155 #endif 156 157 //===----------------------------------------------------------------------===// 158 // Module structure 159 //===----------------------------------------------------------------------===// 160 161 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { 162 auto funcID = funcIDMap.lookup(fnName); 163 if (!funcID) { 164 funcID = getNextID(); 165 funcIDMap[fnName] = funcID; 166 } 167 return funcID; 168 } 169 170 void Serializer::processCapability() { 171 for (auto cap : module.getVceTriple()->getCapabilities()) 172 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, 173 {static_cast<uint32_t>(cap)}); 174 } 175 176 void Serializer::processDebugInfo() { 177 if (!options.emitDebugInfo) 178 return; 179 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc()); 180 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>"; 181 fileID = getNextID(); 182 SmallVector<uint32_t, 16> operands; 183 operands.push_back(fileID); 184 spirv::encodeStringLiteralInto(operands, fileName); 185 encodeInstructionInto(debug, spirv::Opcode::OpString, operands); 186 // TODO: Encode more debug instructions. 187 } 188 189 void Serializer::processExtension() { 190 llvm::SmallVector<uint32_t, 16> extName; 191 for (spirv::Extension ext : module.getVceTriple()->getExtensions()) { 192 extName.clear(); 193 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); 194 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); 195 } 196 } 197 198 void Serializer::processMemoryModel() { 199 StringAttr memoryModelName = module.getMemoryModelAttrName(); 200 auto mm = static_cast<uint32_t>( 201 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName) 202 .getValue()); 203 204 StringAttr addressingModelName = module.getAddressingModelAttrName(); 205 auto am = static_cast<uint32_t>( 206 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName) 207 .getValue()); 208 209 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); 210 } 211 212 static std::string getDecorationName(StringRef attrName) { 213 // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of 214 // expected FPFastMathMode. 215 if (attrName == "fp_fast_math_mode") 216 return "FPFastMathMode"; 217 // similar here 218 if (attrName == "fp_rounding_mode") 219 return "FPRoundingMode"; 220 // convertToCamelFromSnakeCase will not capitalize "INTEL". 221 if (attrName == "cache_control_load_intel") 222 return "CacheControlLoadINTEL"; 223 if (attrName == "cache_control_store_intel") 224 return "CacheControlStoreINTEL"; 225 226 return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true); 227 } 228 229 template <typename AttrTy, typename EmitF> 230 LogicalResult processDecorationList(Location loc, Decoration decoration, 231 Attribute attrList, StringRef attrName, 232 EmitF emitter) { 233 auto arrayAttr = dyn_cast<ArrayAttr>(attrList); 234 if (!arrayAttr) { 235 return emitError(loc, "expecting array attribute of ") 236 << attrName << " for " << stringifyDecoration(decoration); 237 } 238 if (arrayAttr.empty()) { 239 return emitError(loc, "expecting non-empty array attribute of ") 240 << attrName << " for " << stringifyDecoration(decoration); 241 } 242 for (Attribute attr : arrayAttr.getValue()) { 243 auto cacheControlAttr = dyn_cast<AttrTy>(attr); 244 if (!cacheControlAttr) { 245 return emitError(loc, "expecting array attribute of ") 246 << attrName << " for " << stringifyDecoration(decoration); 247 } 248 // This named attribute encodes several decorations. Emit one per 249 // element in the array. 250 if (failed(emitter(cacheControlAttr))) 251 return failure(); 252 } 253 return success(); 254 } 255 256 LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, 257 Decoration decoration, 258 Attribute attr) { 259 SmallVector<uint32_t, 1> args; 260 switch (decoration) { 261 case spirv::Decoration::LinkageAttributes: { 262 // Get the value of the Linkage Attributes 263 // e.g., LinkageAttributes=["linkageName", linkageType]. 264 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr); 265 auto linkageName = linkageAttr.getLinkageName(); 266 auto linkageType = linkageAttr.getLinkageType().getValue(); 267 // Encode the Linkage Name (string literal to uint32_t). 268 spirv::encodeStringLiteralInto(args, linkageName); 269 // Encode LinkageType & Add the Linkagetype to the args. 270 args.push_back(static_cast<uint32_t>(linkageType)); 271 break; 272 } 273 case spirv::Decoration::FPFastMathMode: 274 if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) { 275 args.push_back(static_cast<uint32_t>(intAttr.getValue())); 276 break; 277 } 278 return emitError(loc, "expected FPFastMathModeAttr attribute for ") 279 << stringifyDecoration(decoration); 280 case spirv::Decoration::FPRoundingMode: 281 if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) { 282 args.push_back(static_cast<uint32_t>(intAttr.getValue())); 283 break; 284 } 285 return emitError(loc, "expected FPRoundingModeAttr attribute for ") 286 << stringifyDecoration(decoration); 287 case spirv::Decoration::Binding: 288 case spirv::Decoration::DescriptorSet: 289 case spirv::Decoration::Location: 290 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { 291 args.push_back(intAttr.getValue().getZExtValue()); 292 break; 293 } 294 return emitError(loc, "expected integer attribute for ") 295 << stringifyDecoration(decoration); 296 case spirv::Decoration::BuiltIn: 297 if (auto strAttr = dyn_cast<StringAttr>(attr)) { 298 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); 299 if (enumVal) { 300 args.push_back(static_cast<uint32_t>(*enumVal)); 301 break; 302 } 303 return emitError(loc, "invalid ") 304 << stringifyDecoration(decoration) << " decoration attribute " 305 << strAttr.getValue(); 306 } 307 return emitError(loc, "expected string attribute for ") 308 << stringifyDecoration(decoration); 309 case spirv::Decoration::Aliased: 310 case spirv::Decoration::AliasedPointer: 311 case spirv::Decoration::Flat: 312 case spirv::Decoration::NonReadable: 313 case spirv::Decoration::NonWritable: 314 case spirv::Decoration::NoPerspective: 315 case spirv::Decoration::NoSignedWrap: 316 case spirv::Decoration::NoUnsignedWrap: 317 case spirv::Decoration::RelaxedPrecision: 318 case spirv::Decoration::Restrict: 319 case spirv::Decoration::RestrictPointer: 320 case spirv::Decoration::NoContraction: 321 case spirv::Decoration::Constant: 322 // For unit attributes and decoration attributes, the args list 323 // has no values so we do nothing. 324 if (isa<UnitAttr, DecorationAttr>(attr)) 325 break; 326 return emitError(loc, 327 "expected unit attribute or decoration attribute for ") 328 << stringifyDecoration(decoration); 329 case spirv::Decoration::CacheControlLoadINTEL: 330 return processDecorationList<CacheControlLoadINTELAttr>( 331 loc, decoration, attr, "CacheControlLoadINTEL", 332 [&](CacheControlLoadINTELAttr attr) { 333 unsigned cacheLevel = attr.getCacheLevel(); 334 LoadCacheControl loadCacheControl = attr.getLoadCacheControl(); 335 return emitDecoration( 336 resultID, decoration, 337 {cacheLevel, static_cast<uint32_t>(loadCacheControl)}); 338 }); 339 case spirv::Decoration::CacheControlStoreINTEL: 340 return processDecorationList<CacheControlStoreINTELAttr>( 341 loc, decoration, attr, "CacheControlStoreINTEL", 342 [&](CacheControlStoreINTELAttr attr) { 343 unsigned cacheLevel = attr.getCacheLevel(); 344 StoreCacheControl storeCacheControl = attr.getStoreCacheControl(); 345 return emitDecoration( 346 resultID, decoration, 347 {cacheLevel, static_cast<uint32_t>(storeCacheControl)}); 348 }); 349 default: 350 return emitError(loc, "unhandled decoration ") 351 << stringifyDecoration(decoration); 352 } 353 return emitDecoration(resultID, decoration, args); 354 } 355 356 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, 357 NamedAttribute attr) { 358 StringRef attrName = attr.getName().strref(); 359 std::string decorationName = getDecorationName(attrName); 360 std::optional<Decoration> decoration = 361 spirv::symbolizeDecoration(decorationName); 362 if (!decoration) { 363 return emitError( 364 loc, "non-argument attributes expected to have snake-case-ified " 365 "decoration name, unhandled attribute with name : ") 366 << attrName; 367 } 368 return processDecorationAttr(loc, resultID, *decoration, attr.getValue()); 369 } 370 371 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { 372 assert(!name.empty() && "unexpected empty string for OpName"); 373 if (!options.emitSymbolName) 374 return success(); 375 376 SmallVector<uint32_t, 4> nameOperands; 377 nameOperands.push_back(resultID); 378 spirv::encodeStringLiteralInto(nameOperands, name); 379 encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); 380 return success(); 381 } 382 383 template <> 384 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( 385 Location loc, spirv::ArrayType type, uint32_t resultID) { 386 if (unsigned stride = type.getArrayStride()) { 387 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 388 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 389 } 390 return success(); 391 } 392 393 template <> 394 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( 395 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { 396 if (unsigned stride = type.getArrayStride()) { 397 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 398 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 399 } 400 return success(); 401 } 402 403 LogicalResult Serializer::processMemberDecoration( 404 uint32_t structID, 405 const spirv::StructType::MemberDecorationInfo &memberDecoration) { 406 SmallVector<uint32_t, 4> args( 407 {structID, memberDecoration.memberIndex, 408 static_cast<uint32_t>(memberDecoration.decoration)}); 409 if (memberDecoration.hasValue) { 410 args.push_back(memberDecoration.decorationValue); 411 } 412 encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args); 413 return success(); 414 } 415 416 //===----------------------------------------------------------------------===// 417 // Type 418 //===----------------------------------------------------------------------===// 419 420 // According to the SPIR-V spec "Validation Rules for Shader Capabilities": 421 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and 422 // PushConstant Storage Classes must be explicitly laid out." 423 bool Serializer::isInterfaceStructPtrType(Type type) const { 424 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) { 425 switch (ptrType.getStorageClass()) { 426 case spirv::StorageClass::PhysicalStorageBuffer: 427 case spirv::StorageClass::PushConstant: 428 case spirv::StorageClass::StorageBuffer: 429 case spirv::StorageClass::Uniform: 430 return isa<spirv::StructType>(ptrType.getPointeeType()); 431 default: 432 break; 433 } 434 } 435 return false; 436 } 437 438 LogicalResult Serializer::processType(Location loc, Type type, 439 uint32_t &typeID) { 440 // Maintains a set of names for nested identified struct types. This is used 441 // to properly serialize recursive references. 442 SetVector<StringRef> serializationCtx; 443 return processTypeImpl(loc, type, typeID, serializationCtx); 444 } 445 446 LogicalResult 447 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, 448 SetVector<StringRef> &serializationCtx) { 449 typeID = getTypeID(type); 450 if (typeID) 451 return success(); 452 453 typeID = getNextID(); 454 SmallVector<uint32_t, 4> operands; 455 456 operands.push_back(typeID); 457 auto typeEnum = spirv::Opcode::OpTypeVoid; 458 bool deferSerialization = false; 459 460 if ((isa<FunctionType>(type) && 461 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum, 462 operands))) || 463 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, 464 deferSerialization, serializationCtx))) { 465 if (deferSerialization) 466 return success(); 467 468 typeIDMap[type] = typeID; 469 470 encodeInstructionInto(typesGlobalValues, typeEnum, operands); 471 472 if (recursiveStructInfos.count(type) != 0) { 473 // This recursive struct type is emitted already, now the OpTypePointer 474 // instructions referring to recursive references are emitted as well. 475 for (auto &ptrInfo : recursiveStructInfos[type]) { 476 // TODO: This might not work if more than 1 recursive reference is 477 // present in the struct. 478 SmallVector<uint32_t, 4> ptrOperands; 479 ptrOperands.push_back(ptrInfo.pointerTypeID); 480 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass)); 481 ptrOperands.push_back(typeIDMap[type]); 482 483 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer, 484 ptrOperands); 485 } 486 487 recursiveStructInfos[type].clear(); 488 } 489 490 return success(); 491 } 492 493 return failure(); 494 } 495 496 LogicalResult Serializer::prepareBasicType( 497 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, 498 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, 499 SetVector<StringRef> &serializationCtx) { 500 deferSerialization = false; 501 502 if (isVoidType(type)) { 503 typeEnum = spirv::Opcode::OpTypeVoid; 504 return success(); 505 } 506 507 if (auto intType = dyn_cast<IntegerType>(type)) { 508 if (intType.getWidth() == 1) { 509 typeEnum = spirv::Opcode::OpTypeBool; 510 return success(); 511 } 512 513 typeEnum = spirv::Opcode::OpTypeInt; 514 operands.push_back(intType.getWidth()); 515 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 516 // to preserve or validate. 517 // 0 indicates unsigned, or no signedness semantics 518 // 1 indicates signed semantics." 519 operands.push_back(intType.isSigned() ? 1 : 0); 520 return success(); 521 } 522 523 if (auto floatType = dyn_cast<FloatType>(type)) { 524 typeEnum = spirv::Opcode::OpTypeFloat; 525 operands.push_back(floatType.getWidth()); 526 return success(); 527 } 528 529 if (auto vectorType = dyn_cast<VectorType>(type)) { 530 uint32_t elementTypeID = 0; 531 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, 532 serializationCtx))) { 533 return failure(); 534 } 535 typeEnum = spirv::Opcode::OpTypeVector; 536 operands.push_back(elementTypeID); 537 operands.push_back(vectorType.getNumElements()); 538 return success(); 539 } 540 541 if (auto imageType = dyn_cast<spirv::ImageType>(type)) { 542 typeEnum = spirv::Opcode::OpTypeImage; 543 uint32_t sampledTypeID = 0; 544 if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) 545 return failure(); 546 547 llvm::append_values(operands, sampledTypeID, 548 static_cast<uint32_t>(imageType.getDim()), 549 static_cast<uint32_t>(imageType.getDepthInfo()), 550 static_cast<uint32_t>(imageType.getArrayedInfo()), 551 static_cast<uint32_t>(imageType.getSamplingInfo()), 552 static_cast<uint32_t>(imageType.getSamplerUseInfo()), 553 static_cast<uint32_t>(imageType.getImageFormat())); 554 return success(); 555 } 556 557 if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) { 558 typeEnum = spirv::Opcode::OpTypeArray; 559 uint32_t elementTypeID = 0; 560 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, 561 serializationCtx))) { 562 return failure(); 563 } 564 operands.push_back(elementTypeID); 565 if (auto elementCountID = prepareConstantInt( 566 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { 567 operands.push_back(elementCountID); 568 } 569 return processTypeDecoration(loc, arrayType, resultID); 570 } 571 572 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) { 573 uint32_t pointeeTypeID = 0; 574 spirv::StructType pointeeStruct = 575 dyn_cast<spirv::StructType>(ptrType.getPointeeType()); 576 577 if (pointeeStruct && pointeeStruct.isIdentified() && 578 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { 579 // A recursive reference to an enclosing struct is found. 580 // 581 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage 582 // class as operands. 583 SmallVector<uint32_t, 2> forwardPtrOperands; 584 forwardPtrOperands.push_back(resultID); 585 forwardPtrOperands.push_back( 586 static_cast<uint32_t>(ptrType.getStorageClass())); 587 588 encodeInstructionInto(typesGlobalValues, 589 spirv::Opcode::OpTypeForwardPointer, 590 forwardPtrOperands); 591 592 // 2. Find the pointee (enclosing) struct. 593 auto structType = spirv::StructType::getIdentified( 594 module.getContext(), pointeeStruct.getIdentifier()); 595 596 if (!structType) 597 return failure(); 598 599 // 3. Mark the OpTypePointer that is supposed to be emitted by this call 600 // as deferred. 601 deferSerialization = true; 602 603 // 4. Record the info needed to emit the deferred OpTypePointer 604 // instruction when the enclosing struct is completely serialized. 605 recursiveStructInfos[structType].push_back( 606 {resultID, ptrType.getStorageClass()}); 607 } else { 608 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, 609 serializationCtx))) 610 return failure(); 611 } 612 613 typeEnum = spirv::Opcode::OpTypePointer; 614 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); 615 operands.push_back(pointeeTypeID); 616 617 if (isInterfaceStructPtrType(ptrType)) { 618 if (failed(emitDecoration(getTypeID(pointeeStruct), 619 spirv::Decoration::Block))) 620 return emitError(loc, "cannot decorate ") 621 << pointeeStruct << " with Block decoration"; 622 } 623 624 return success(); 625 } 626 627 if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) { 628 uint32_t elementTypeID = 0; 629 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), 630 elementTypeID, serializationCtx))) { 631 return failure(); 632 } 633 typeEnum = spirv::Opcode::OpTypeRuntimeArray; 634 operands.push_back(elementTypeID); 635 return processTypeDecoration(loc, runtimeArrayType, resultID); 636 } 637 638 if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) { 639 typeEnum = spirv::Opcode::OpTypeSampledImage; 640 uint32_t imageTypeID = 0; 641 if (failed( 642 processType(loc, sampledImageType.getImageType(), imageTypeID))) { 643 return failure(); 644 } 645 operands.push_back(imageTypeID); 646 return success(); 647 } 648 649 if (auto structType = dyn_cast<spirv::StructType>(type)) { 650 if (structType.isIdentified()) { 651 if (failed(processName(resultID, structType.getIdentifier()))) 652 return failure(); 653 serializationCtx.insert(structType.getIdentifier()); 654 } 655 656 bool hasOffset = structType.hasOffset(); 657 for (auto elementIndex : 658 llvm::seq<uint32_t>(0, structType.getNumElements())) { 659 uint32_t elementTypeID = 0; 660 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), 661 elementTypeID, serializationCtx))) { 662 return failure(); 663 } 664 operands.push_back(elementTypeID); 665 if (hasOffset) { 666 // Decorate each struct member with an offset 667 spirv::StructType::MemberDecorationInfo offsetDecoration{ 668 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, 669 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; 670 if (failed(processMemberDecoration(resultID, offsetDecoration))) { 671 return emitError(loc, "cannot decorate ") 672 << elementIndex << "-th member of " << structType 673 << " with its offset"; 674 } 675 } 676 } 677 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 678 structType.getMemberDecorations(memberDecorations); 679 680 for (auto &memberDecoration : memberDecorations) { 681 if (failed(processMemberDecoration(resultID, memberDecoration))) { 682 return emitError(loc, "cannot decorate ") 683 << static_cast<uint32_t>(memberDecoration.memberIndex) 684 << "-th member of " << structType << " with " 685 << stringifyDecoration(memberDecoration.decoration); 686 } 687 } 688 689 typeEnum = spirv::Opcode::OpTypeStruct; 690 691 if (structType.isIdentified()) 692 serializationCtx.remove(structType.getIdentifier()); 693 694 return success(); 695 } 696 697 if (auto cooperativeMatrixType = 698 dyn_cast<spirv::CooperativeMatrixType>(type)) { 699 uint32_t elementTypeID = 0; 700 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), 701 elementTypeID, serializationCtx))) { 702 return failure(); 703 } 704 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR; 705 auto getConstantOp = [&](uint32_t id) { 706 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); 707 return prepareConstantInt(loc, attr); 708 }; 709 llvm::append_values( 710 operands, elementTypeID, 711 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())), 712 getConstantOp(cooperativeMatrixType.getRows()), 713 getConstantOp(cooperativeMatrixType.getColumns()), 714 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse()))); 715 return success(); 716 } 717 718 if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) { 719 uint32_t elementTypeID = 0; 720 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, 721 serializationCtx))) { 722 return failure(); 723 } 724 typeEnum = spirv::Opcode::OpTypeMatrix; 725 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns()); 726 return success(); 727 } 728 729 // TODO: Handle other types. 730 return emitError(loc, "unhandled type in serialization: ") << type; 731 } 732 733 LogicalResult 734 Serializer::prepareFunctionType(Location loc, FunctionType type, 735 spirv::Opcode &typeEnum, 736 SmallVectorImpl<uint32_t> &operands) { 737 typeEnum = spirv::Opcode::OpTypeFunction; 738 assert(type.getNumResults() <= 1 && 739 "serialization supports only a single return value"); 740 uint32_t resultID = 0; 741 if (failed(processType( 742 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), 743 resultID))) { 744 return failure(); 745 } 746 operands.push_back(resultID); 747 for (auto &res : type.getInputs()) { 748 uint32_t argTypeID = 0; 749 if (failed(processType(loc, res, argTypeID))) { 750 return failure(); 751 } 752 operands.push_back(argTypeID); 753 } 754 return success(); 755 } 756 757 //===----------------------------------------------------------------------===// 758 // Constant 759 //===----------------------------------------------------------------------===// 760 761 uint32_t Serializer::prepareConstant(Location loc, Type constType, 762 Attribute valueAttr) { 763 if (auto id = prepareConstantScalar(loc, valueAttr)) { 764 return id; 765 } 766 767 // This is a composite literal. We need to handle each component separately 768 // and then emit an OpConstantComposite for the whole. 769 770 if (auto id = getConstantID(valueAttr)) { 771 return id; 772 } 773 774 uint32_t typeID = 0; 775 if (failed(processType(loc, constType, typeID))) { 776 return 0; 777 } 778 779 uint32_t resultID = 0; 780 if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) { 781 int rank = dyn_cast<ShapedType>(attr.getType()).getRank(); 782 SmallVector<uint64_t, 4> index(rank); 783 resultID = prepareDenseElementsConstant(loc, constType, attr, 784 /*dim=*/0, index); 785 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) { 786 resultID = prepareArrayConstant(loc, constType, arrayAttr); 787 } 788 789 if (resultID == 0) { 790 emitError(loc, "cannot serialize attribute: ") << valueAttr; 791 return 0; 792 } 793 794 constIDMap[valueAttr] = resultID; 795 return resultID; 796 } 797 798 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, 799 ArrayAttr attr) { 800 uint32_t typeID = 0; 801 if (failed(processType(loc, constType, typeID))) { 802 return 0; 803 } 804 805 uint32_t resultID = getNextID(); 806 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 807 operands.reserve(attr.size() + 2); 808 auto elementType = cast<spirv::ArrayType>(constType).getElementType(); 809 for (Attribute elementAttr : attr) { 810 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { 811 operands.push_back(elementID); 812 } else { 813 return 0; 814 } 815 } 816 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 817 encodeInstructionInto(typesGlobalValues, opcode, operands); 818 819 return resultID; 820 } 821 822 // TODO: Turn the below function into iterative function, instead of 823 // recursive function. 824 uint32_t 825 Serializer::prepareDenseElementsConstant(Location loc, Type constType, 826 DenseElementsAttr valueAttr, int dim, 827 MutableArrayRef<uint64_t> index) { 828 auto shapedType = dyn_cast<ShapedType>(valueAttr.getType()); 829 assert(dim <= shapedType.getRank()); 830 if (shapedType.getRank() == dim) { 831 if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) { 832 return attr.getType().getElementType().isInteger(1) 833 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index]) 834 : prepareConstantInt(loc, 835 attr.getValues<IntegerAttr>()[index]); 836 } 837 if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) { 838 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]); 839 } 840 return 0; 841 } 842 843 uint32_t typeID = 0; 844 if (failed(processType(loc, constType, typeID))) { 845 return 0; 846 } 847 848 uint32_t resultID = getNextID(); 849 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 850 operands.reserve(shapedType.getDimSize(dim) + 2); 851 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0); 852 for (int i = 0; i < shapedType.getDimSize(dim); ++i) { 853 index[dim] = i; 854 if (auto elementID = prepareDenseElementsConstant( 855 loc, elementType, valueAttr, dim + 1, index)) { 856 operands.push_back(elementID); 857 } else { 858 return 0; 859 } 860 } 861 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 862 encodeInstructionInto(typesGlobalValues, opcode, operands); 863 864 return resultID; 865 } 866 867 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, 868 bool isSpec) { 869 if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) { 870 return prepareConstantFp(loc, floatAttr, isSpec); 871 } 872 if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) { 873 return prepareConstantBool(loc, boolAttr, isSpec); 874 } 875 if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) { 876 return prepareConstantInt(loc, intAttr, isSpec); 877 } 878 879 return 0; 880 } 881 882 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, 883 bool isSpec) { 884 if (!isSpec) { 885 // We can de-duplicate normal constants, but not specialization constants. 886 if (auto id = getConstantID(boolAttr)) { 887 return id; 888 } 889 } 890 891 // Process the type for this bool literal 892 uint32_t typeID = 0; 893 if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) { 894 return 0; 895 } 896 897 auto resultID = getNextID(); 898 auto opcode = boolAttr.getValue() 899 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue 900 : spirv::Opcode::OpConstantTrue) 901 : (isSpec ? spirv::Opcode::OpSpecConstantFalse 902 : spirv::Opcode::OpConstantFalse); 903 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); 904 905 if (!isSpec) { 906 constIDMap[boolAttr] = resultID; 907 } 908 return resultID; 909 } 910 911 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, 912 bool isSpec) { 913 if (!isSpec) { 914 // We can de-duplicate normal constants, but not specialization constants. 915 if (auto id = getConstantID(intAttr)) { 916 return id; 917 } 918 } 919 920 // Process the type for this integer literal 921 uint32_t typeID = 0; 922 if (failed(processType(loc, intAttr.getType(), typeID))) { 923 return 0; 924 } 925 926 auto resultID = getNextID(); 927 APInt value = intAttr.getValue(); 928 unsigned bitwidth = value.getBitWidth(); 929 bool isSigned = intAttr.getType().isSignedInteger(); 930 auto opcode = 931 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 932 933 switch (bitwidth) { 934 // According to SPIR-V spec, "When the type's bit width is less than 935 // 32-bits, the literal's value appears in the low-order bits of the word, 936 // and the high-order bits must be 0 for a floating-point type, or 0 for an 937 // integer type with Signedness of 0, or sign extended when Signedness 938 // is 1." 939 case 32: 940 case 16: 941 case 8: { 942 uint32_t word = 0; 943 if (isSigned) { 944 word = static_cast<int32_t>(value.getSExtValue()); 945 } else { 946 word = static_cast<uint32_t>(value.getZExtValue()); 947 } 948 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); 949 } break; 950 // According to SPIR-V spec: "When the type's bit width is larger than one 951 // word, the literal’s low-order words appear first." 952 case 64: { 953 struct DoubleWord { 954 uint32_t word1; 955 uint32_t word2; 956 } words; 957 if (isSigned) { 958 words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); 959 } else { 960 words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); 961 } 962 encodeInstructionInto(typesGlobalValues, opcode, 963 {typeID, resultID, words.word1, words.word2}); 964 } break; 965 default: { 966 std::string valueStr; 967 llvm::raw_string_ostream rss(valueStr); 968 value.print(rss, /*isSigned=*/false); 969 970 emitError(loc, "cannot serialize ") 971 << bitwidth << "-bit integer literal: " << valueStr; 972 return 0; 973 } 974 } 975 976 if (!isSpec) { 977 constIDMap[intAttr] = resultID; 978 } 979 return resultID; 980 } 981 982 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, 983 bool isSpec) { 984 if (!isSpec) { 985 // We can de-duplicate normal constants, but not specialization constants. 986 if (auto id = getConstantID(floatAttr)) { 987 return id; 988 } 989 } 990 991 // Process the type for this float literal 992 uint32_t typeID = 0; 993 if (failed(processType(loc, floatAttr.getType(), typeID))) { 994 return 0; 995 } 996 997 auto resultID = getNextID(); 998 APFloat value = floatAttr.getValue(); 999 APInt intValue = value.bitcastToAPInt(); 1000 1001 auto opcode = 1002 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 1003 1004 if (&value.getSemantics() == &APFloat::IEEEsingle()) { 1005 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); 1006 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); 1007 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { 1008 struct DoubleWord { 1009 uint32_t word1; 1010 uint32_t word2; 1011 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); 1012 encodeInstructionInto(typesGlobalValues, opcode, 1013 {typeID, resultID, words.word1, words.word2}); 1014 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { 1015 uint32_t word = 1016 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); 1017 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); 1018 } else { 1019 std::string valueStr; 1020 llvm::raw_string_ostream rss(valueStr); 1021 value.print(rss); 1022 1023 emitError(loc, "cannot serialize ") 1024 << floatAttr.getType() << "-typed float literal: " << valueStr; 1025 return 0; 1026 } 1027 1028 if (!isSpec) { 1029 constIDMap[floatAttr] = resultID; 1030 } 1031 return resultID; 1032 } 1033 1034 //===----------------------------------------------------------------------===// 1035 // Control flow 1036 //===----------------------------------------------------------------------===// 1037 1038 uint32_t Serializer::getOrCreateBlockID(Block *block) { 1039 if (uint32_t id = getBlockID(block)) 1040 return id; 1041 return blockIDMap[block] = getNextID(); 1042 } 1043 1044 #ifndef NDEBUG 1045 void Serializer::printBlock(Block *block, raw_ostream &os) { 1046 os << "block " << block << " (id = "; 1047 if (uint32_t id = getBlockID(block)) 1048 os << id; 1049 else 1050 os << "unknown"; 1051 os << ")\n"; 1052 } 1053 #endif 1054 1055 LogicalResult 1056 Serializer::processBlock(Block *block, bool omitLabel, 1057 function_ref<LogicalResult()> emitMerge) { 1058 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); 1059 LLVM_DEBUG(block->print(llvm::dbgs())); 1060 LLVM_DEBUG(llvm::dbgs() << '\n'); 1061 if (!omitLabel) { 1062 uint32_t blockID = getOrCreateBlockID(block); 1063 LLVM_DEBUG(printBlock(block, llvm::dbgs())); 1064 1065 // Emit OpLabel for this block. 1066 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); 1067 } 1068 1069 // Emit OpPhi instructions for block arguments, if any. 1070 if (failed(emitPhiForBlockArguments(block))) 1071 return failure(); 1072 1073 // If we need to emit merge instructions, it must happen in this block. Check 1074 // whether we have other structured control flow ops, which will be expanded 1075 // into multiple basic blocks. If that's the case, we need to emit the merge 1076 // right now and then create new blocks for further serialization of the ops 1077 // in this block. 1078 if (emitMerge && 1079 llvm::any_of(block->getOperations(), 1080 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) { 1081 if (failed(emitMerge())) 1082 return failure(); 1083 emitMerge = nullptr; 1084 1085 // Start a new block for further serialization. 1086 uint32_t blockID = getNextID(); 1087 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID}); 1088 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); 1089 } 1090 1091 // Process each op in this block except the terminator. 1092 for (Operation &op : llvm::drop_end(*block)) { 1093 if (failed(processOperation(&op))) 1094 return failure(); 1095 } 1096 1097 // Process the terminator. 1098 if (emitMerge) 1099 if (failed(emitMerge())) 1100 return failure(); 1101 if (failed(processOperation(&block->back()))) 1102 return failure(); 1103 1104 return success(); 1105 } 1106 1107 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { 1108 // Nothing to do if this block has no arguments or it's the entry block, which 1109 // always has the same arguments as the function signature. 1110 if (block->args_empty() || block->isEntryBlock()) 1111 return success(); 1112 1113 LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n"); 1114 1115 // If the block has arguments, we need to create SPIR-V OpPhi instructions. 1116 // A SPIR-V OpPhi instruction is of the syntax: 1117 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair 1118 // So we need to collect all predecessor blocks and the arguments they send 1119 // to this block. 1120 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors; 1121 for (Block *mlirPredecessor : block->getPredecessors()) { 1122 auto *terminator = mlirPredecessor->getTerminator(); 1123 LLVM_DEBUG(llvm::dbgs() << " mlir predecessor "); 1124 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs())); 1125 LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n"); 1126 // The predecessor here is the immediate one according to MLIR's IR 1127 // structure. It does not directly map to the incoming parent block for the 1128 // OpPhi instructions at SPIR-V binary level. This is because structured 1129 // control flow ops are serialized to multiple SPIR-V blocks. If there is a 1130 // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block, 1131 // the branch op jumping to the OpPhi's block then resides in the previous 1132 // structured control flow op's merge block. 1133 Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor); 1134 LLVM_DEBUG(llvm::dbgs() << " spirv predecessor "); 1135 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs())); 1136 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { 1137 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands()); 1138 } else if (auto branchCondOp = 1139 dyn_cast<spirv::BranchConditionalOp>(terminator)) { 1140 std::optional<OperandRange> blockOperands; 1141 if (branchCondOp.getTrueTarget() == block) { 1142 blockOperands = branchCondOp.getTrueTargetOperands(); 1143 } else { 1144 assert(branchCondOp.getFalseTarget() == block); 1145 blockOperands = branchCondOp.getFalseTargetOperands(); 1146 } 1147 1148 assert(!blockOperands->empty() && 1149 "expected non-empty block operand range"); 1150 predecessors.emplace_back(spirvPredecessor, *blockOperands); 1151 } else { 1152 return terminator->emitError("unimplemented terminator for Phi creation"); 1153 } 1154 LLVM_DEBUG({ 1155 llvm::dbgs() << " block arguments:\n"; 1156 for (Value v : predecessors.back().second) 1157 llvm::dbgs() << " " << v << "\n"; 1158 }); 1159 } 1160 1161 // Then create OpPhi instruction for each of the block argument. 1162 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { 1163 BlockArgument arg = block->getArgument(argIndex); 1164 1165 // Get the type <id> and result <id> for this OpPhi instruction. 1166 uint32_t phiTypeID = 0; 1167 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) 1168 return failure(); 1169 uint32_t phiID = getNextID(); 1170 1171 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' 1172 << arg << " (id = " << phiID << ")\n"); 1173 1174 // Prepare the (value <id>, parent block <id>) pairs. 1175 SmallVector<uint32_t, 8> phiArgs; 1176 phiArgs.push_back(phiTypeID); 1177 phiArgs.push_back(phiID); 1178 1179 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { 1180 Value value = predecessors[predIndex].second[argIndex]; 1181 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); 1182 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId 1183 << ") value " << value << ' '); 1184 // Each pair is a value <id> ... 1185 uint32_t valueId = getValueID(value); 1186 if (valueId == 0) { 1187 // The op generating this value hasn't been visited yet so we don't have 1188 // an <id> assigned yet. Record this to fix up later. 1189 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); 1190 deferredPhiValues[value].push_back(functionBody.size() + 1 + 1191 phiArgs.size()); 1192 } else { 1193 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); 1194 } 1195 phiArgs.push_back(valueId); 1196 // ... and a parent block <id>. 1197 phiArgs.push_back(predBlockId); 1198 } 1199 1200 encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); 1201 valueIDMap[arg] = phiID; 1202 } 1203 1204 return success(); 1205 } 1206 1207 //===----------------------------------------------------------------------===// 1208 // Operation 1209 //===----------------------------------------------------------------------===// 1210 1211 LogicalResult Serializer::encodeExtensionInstruction( 1212 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, 1213 ArrayRef<uint32_t> operands) { 1214 // Check if the extension has been imported. 1215 auto &setID = extendedInstSetIDMap[extensionSetName]; 1216 if (!setID) { 1217 setID = getNextID(); 1218 SmallVector<uint32_t, 16> importOperands; 1219 importOperands.push_back(setID); 1220 spirv::encodeStringLiteralInto(importOperands, extensionSetName); 1221 encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport, 1222 importOperands); 1223 } 1224 1225 // The first two operands are the result type <id> and result <id>. The set 1226 // <id> and the opcode need to be insert after this. 1227 if (operands.size() < 2) { 1228 return op->emitError("extended instructions must have a result encoding"); 1229 } 1230 SmallVector<uint32_t, 8> extInstOperands; 1231 extInstOperands.reserve(operands.size() + 2); 1232 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); 1233 extInstOperands.push_back(setID); 1234 extInstOperands.push_back(extensionOpcode); 1235 extInstOperands.append(std::next(operands.begin(), 2), operands.end()); 1236 encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, 1237 extInstOperands); 1238 return success(); 1239 } 1240 1241 LogicalResult Serializer::processOperation(Operation *opInst) { 1242 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); 1243 1244 // First dispatch the ops that do not directly mirror an instruction from 1245 // the SPIR-V spec. 1246 return TypeSwitch<Operation *, LogicalResult>(opInst) 1247 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) 1248 .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) 1249 .Case([&](spirv::BranchConditionalOp op) { 1250 return processBranchConditionalOp(op); 1251 }) 1252 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) 1253 .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) 1254 .Case([&](spirv::GlobalVariableOp op) { 1255 return processGlobalVariableOp(op); 1256 }) 1257 .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) 1258 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) 1259 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) 1260 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) 1261 .Case([&](spirv::SpecConstantCompositeOp op) { 1262 return processSpecConstantCompositeOp(op); 1263 }) 1264 .Case([&](spirv::SpecConstantOperationOp op) { 1265 return processSpecConstantOperationOp(op); 1266 }) 1267 .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) 1268 .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) 1269 1270 // Then handle all the ops that directly mirror SPIR-V instructions with 1271 // auto-generated methods. 1272 .Default( 1273 [&](Operation *op) { return dispatchToAutogenSerialization(op); }); 1274 } 1275 1276 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, 1277 StringRef extInstSet, 1278 uint32_t opcode) { 1279 SmallVector<uint32_t, 4> operands; 1280 Location loc = op->getLoc(); 1281 1282 uint32_t resultID = 0; 1283 if (op->getNumResults() != 0) { 1284 uint32_t resultTypeID = 0; 1285 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) 1286 return failure(); 1287 operands.push_back(resultTypeID); 1288 1289 resultID = getNextID(); 1290 operands.push_back(resultID); 1291 valueIDMap[op->getResult(0)] = resultID; 1292 }; 1293 1294 for (Value operand : op->getOperands()) 1295 operands.push_back(getValueID(operand)); 1296 1297 if (failed(emitDebugLine(functionBody, loc))) 1298 return failure(); 1299 1300 if (extInstSet.empty()) { 1301 encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode), 1302 operands); 1303 } else { 1304 if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands))) 1305 return failure(); 1306 } 1307 1308 if (op->getNumResults() != 0) { 1309 for (auto attr : op->getAttrs()) { 1310 if (failed(processDecoration(loc, resultID, attr))) 1311 return failure(); 1312 } 1313 } 1314 1315 return success(); 1316 } 1317 1318 LogicalResult Serializer::emitDecoration(uint32_t target, 1319 spirv::Decoration decoration, 1320 ArrayRef<uint32_t> params) { 1321 uint32_t wordCount = 3 + params.size(); 1322 llvm::append_values( 1323 decorations, 1324 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target, 1325 static_cast<uint32_t>(decoration)); 1326 llvm::append_range(decorations, params); 1327 return success(); 1328 } 1329 1330 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, 1331 Location loc) { 1332 if (!options.emitDebugInfo) 1333 return success(); 1334 1335 if (lastProcessedWasMergeInst) { 1336 lastProcessedWasMergeInst = false; 1337 return success(); 1338 } 1339 1340 auto fileLoc = dyn_cast<FileLineColLoc>(loc); 1341 if (fileLoc) 1342 encodeInstructionInto(binary, spirv::Opcode::OpLine, 1343 {fileID, fileLoc.getLine(), fileLoc.getColumn()}); 1344 return success(); 1345 } 1346 } // namespace spirv 1347 } // namespace mlir 1348