1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Deserializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/IRMapping.h" 21 #include "mlir/IR/Location.h" 22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/Sequence.h" 25 #include "llvm/ADT/SmallVector.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/ADT/bit.h" 28 #include "llvm/Support/Debug.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/raw_ostream.h" 31 #include <optional> 32 33 using namespace mlir; 34 35 #define DEBUG_TYPE "spirv-deserialization" 36 37 //===----------------------------------------------------------------------===// 38 // Utility Functions 39 //===----------------------------------------------------------------------===// 40 41 /// Returns true if the given `block` is a function entry block. 42 static inline bool isFnEntryBlock(Block *block) { 43 return block->isEntryBlock() && 44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp()); 45 } 46 47 //===----------------------------------------------------------------------===// 48 // Deserializer Method Definitions 49 //===----------------------------------------------------------------------===// 50 51 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary, 52 MLIRContext *context) 53 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), 54 module(createModuleOp()), opBuilder(module->getRegion()) 55 #ifndef NDEBUG 56 , 57 logger(llvm::dbgs()) 58 #endif 59 { 60 } 61 62 LogicalResult spirv::Deserializer::deserialize() { 63 LLVM_DEBUG({ 64 logger.resetIndent(); 65 logger.startLine() 66 << "//+++---------- start deserialization ----------+++//\n"; 67 }); 68 69 if (failed(processHeader())) 70 return failure(); 71 72 spirv::Opcode opcode = spirv::Opcode::OpNop; 73 ArrayRef<uint32_t> operands; 74 auto binarySize = binary.size(); 75 while (curOffset < binarySize) { 76 // Slice the next instruction out and populate `opcode` and `operands`. 77 // Internally this also updates `curOffset`. 78 if (failed(sliceInstruction(opcode, operands))) 79 return failure(); 80 81 if (failed(processInstruction(opcode, operands))) 82 return failure(); 83 } 84 85 assert(curOffset == binarySize && 86 "deserializer should never index beyond the binary end"); 87 88 for (auto &deferred : deferredInstructions) { 89 if (failed(processInstruction(deferred.first, deferred.second, false))) { 90 return failure(); 91 } 92 } 93 94 attachVCETriple(); 95 96 LLVM_DEBUG(logger.startLine() 97 << "//+++-------- completed deserialization --------+++//\n"); 98 return success(); 99 } 100 101 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() { 102 return std::move(module); 103 } 104 105 //===----------------------------------------------------------------------===// 106 // Module structure 107 //===----------------------------------------------------------------------===// 108 109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() { 110 OpBuilder builder(context); 111 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); 112 spirv::ModuleOp::build(builder, state); 113 return cast<spirv::ModuleOp>(Operation::create(state)); 114 } 115 116 LogicalResult spirv::Deserializer::processHeader() { 117 if (binary.size() < spirv::kHeaderWordCount) 118 return emitError(unknownLoc, 119 "SPIR-V binary module must have a 5-word header"); 120 121 if (binary[0] != spirv::kMagicNumber) 122 return emitError(unknownLoc, "incorrect magic number"); 123 124 // Version number bytes: 0 | major number | minor number | 0 125 uint32_t majorVersion = (binary[1] << 8) >> 24; 126 uint32_t minorVersion = (binary[1] << 16) >> 24; 127 if (majorVersion == 1) { 128 switch (minorVersion) { 129 #define MIN_VERSION_CASE(v) \ 130 case v: \ 131 version = spirv::Version::V_1_##v; \ 132 break 133 134 MIN_VERSION_CASE(0); 135 MIN_VERSION_CASE(1); 136 MIN_VERSION_CASE(2); 137 MIN_VERSION_CASE(3); 138 MIN_VERSION_CASE(4); 139 MIN_VERSION_CASE(5); 140 #undef MIN_VERSION_CASE 141 default: 142 return emitError(unknownLoc, "unsupported SPIR-V minor version: ") 143 << minorVersion; 144 } 145 } else { 146 return emitError(unknownLoc, "unsupported SPIR-V major version: ") 147 << majorVersion; 148 } 149 150 // TODO: generator number, bound, schema 151 curOffset = spirv::kHeaderWordCount; 152 return success(); 153 } 154 155 LogicalResult 156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) { 157 if (operands.size() != 1) 158 return emitError(unknownLoc, "OpMemoryModel must have one parameter"); 159 160 auto cap = spirv::symbolizeCapability(operands[0]); 161 if (!cap) 162 return emitError(unknownLoc, "unknown capability: ") << operands[0]; 163 164 capabilities.insert(*cap); 165 return success(); 166 } 167 168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) { 169 if (words.empty()) { 170 return emitError( 171 unknownLoc, 172 "OpExtension must have a literal string for the extension name"); 173 } 174 175 unsigned wordIndex = 0; 176 StringRef extName = decodeStringLiteral(words, wordIndex); 177 if (wordIndex != words.size()) 178 return emitError(unknownLoc, 179 "unexpected trailing words in OpExtension instruction"); 180 auto ext = spirv::symbolizeExtension(extName); 181 if (!ext) 182 return emitError(unknownLoc, "unknown extension: ") << extName; 183 184 extensions.insert(*ext); 185 return success(); 186 } 187 188 LogicalResult 189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) { 190 if (words.size() < 2) { 191 return emitError(unknownLoc, 192 "OpExtInstImport must have a result <id> and a literal " 193 "string for the extended instruction set name"); 194 } 195 196 unsigned wordIndex = 1; 197 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); 198 if (wordIndex != words.size()) { 199 return emitError(unknownLoc, 200 "unexpected trailing words in OpExtInstImport"); 201 } 202 return success(); 203 } 204 205 void spirv::Deserializer::attachVCETriple() { 206 (*module)->setAttr( 207 spirv::ModuleOp::getVCETripleAttrName(), 208 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(), 209 extensions.getArrayRef(), context)); 210 } 211 212 LogicalResult 213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { 214 if (operands.size() != 2) 215 return emitError(unknownLoc, "OpMemoryModel must have two operands"); 216 217 (*module)->setAttr( 218 module->getAddressingModelAttrName(), 219 opBuilder.getAttr<spirv::AddressingModelAttr>( 220 static_cast<spirv::AddressingModel>(operands.front()))); 221 222 (*module)->setAttr(module->getMemoryModelAttrName(), 223 opBuilder.getAttr<spirv::MemoryModelAttr>( 224 static_cast<spirv::MemoryModel>(operands.back()))); 225 226 return success(); 227 } 228 229 template <typename AttrTy, typename EnumAttrTy, typename EnumTy> 230 LogicalResult deserializeCacheControlDecoration( 231 Location loc, OpBuilder &opBuilder, 232 DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words, 233 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) { 234 if (words.size() != 4) { 235 return emitError(loc, "OpDecoration with ") 236 << decorationName << "needs a cache control integer literal and a " 237 << cacheControlKind << " cache control literal"; 238 } 239 unsigned cacheLevel = words[2]; 240 auto cacheControlAttr = static_cast<EnumTy>(words[3]); 241 auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr); 242 SmallVector<Attribute> attrs; 243 if (auto attrList = 244 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol))) 245 llvm::append_range(attrs, attrList); 246 attrs.push_back(value); 247 decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs)); 248 return success(); 249 } 250 251 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { 252 // TODO: This function should also be auto-generated. For now, since only a 253 // few decorations are processed/handled in a meaningful manner, going with a 254 // manual implementation. 255 if (words.size() < 2) { 256 return emitError( 257 unknownLoc, "OpDecorate must have at least result <id> and Decoration"); 258 } 259 auto decorationName = 260 stringifyDecoration(static_cast<spirv::Decoration>(words[1])); 261 if (decorationName.empty()) { 262 return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; 263 } 264 auto symbol = getSymbolDecoration(decorationName); 265 switch (static_cast<spirv::Decoration>(words[1])) { 266 case spirv::Decoration::FPFastMathMode: 267 if (words.size() != 3) { 268 return emitError(unknownLoc, "OpDecorate with ") 269 << decorationName << " needs a single integer literal"; 270 } 271 decorations[words[0]].set( 272 symbol, FPFastMathModeAttr::get(opBuilder.getContext(), 273 static_cast<FPFastMathMode>(words[2]))); 274 break; 275 case spirv::Decoration::FPRoundingMode: 276 if (words.size() != 3) { 277 return emitError(unknownLoc, "OpDecorate with ") 278 << decorationName << " needs a single integer literal"; 279 } 280 decorations[words[0]].set( 281 symbol, FPRoundingModeAttr::get(opBuilder.getContext(), 282 static_cast<FPRoundingMode>(words[2]))); 283 break; 284 case spirv::Decoration::DescriptorSet: 285 case spirv::Decoration::Binding: 286 if (words.size() != 3) { 287 return emitError(unknownLoc, "OpDecorate with ") 288 << decorationName << " needs a single integer literal"; 289 } 290 decorations[words[0]].set( 291 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 292 break; 293 case spirv::Decoration::BuiltIn: 294 if (words.size() != 3) { 295 return emitError(unknownLoc, "OpDecorate with ") 296 << decorationName << " needs a single integer literal"; 297 } 298 decorations[words[0]].set( 299 symbol, opBuilder.getStringAttr( 300 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2])))); 301 break; 302 case spirv::Decoration::ArrayStride: 303 if (words.size() != 3) { 304 return emitError(unknownLoc, "OpDecorate with ") 305 << decorationName << " needs a single integer literal"; 306 } 307 typeDecorations[words[0]] = words[2]; 308 break; 309 case spirv::Decoration::LinkageAttributes: { 310 if (words.size() < 4) { 311 return emitError(unknownLoc, "OpDecorate with ") 312 << decorationName 313 << " needs at least 1 string and 1 integer literal"; 314 } 315 // LinkageAttributes has two parameters ["linkageName", linkageType] 316 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import 317 // "linkageName" is a stringliteral encoded as uint32_t, 318 // hence the size of name is variable length which results in words.size() 319 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or 320 // 3 + ceildiv(strlen(name), 4). 321 unsigned wordIndex = 2; 322 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str(); 323 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>( 324 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++])); 325 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>( 326 StringAttr::get(context, linkageName), linkageTypeAttr); 327 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr)); 328 break; 329 } 330 case spirv::Decoration::Aliased: 331 case spirv::Decoration::AliasedPointer: 332 case spirv::Decoration::Block: 333 case spirv::Decoration::BufferBlock: 334 case spirv::Decoration::Flat: 335 case spirv::Decoration::NonReadable: 336 case spirv::Decoration::NonWritable: 337 case spirv::Decoration::NoPerspective: 338 case spirv::Decoration::NoSignedWrap: 339 case spirv::Decoration::NoUnsignedWrap: 340 case spirv::Decoration::RelaxedPrecision: 341 case spirv::Decoration::Restrict: 342 case spirv::Decoration::RestrictPointer: 343 case spirv::Decoration::NoContraction: 344 case spirv::Decoration::Constant: 345 if (words.size() != 2) { 346 return emitError(unknownLoc, "OpDecoration with ") 347 << decorationName << "needs a single target <id>"; 348 } 349 // Block decoration does not affect spirv.struct type, but is still stored 350 // for verification. 351 // TODO: Update StructType to contain this information since 352 // it is needed for many validation rules. 353 decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); 354 break; 355 case spirv::Decoration::Location: 356 case spirv::Decoration::SpecId: 357 if (words.size() != 3) { 358 return emitError(unknownLoc, "OpDecoration with ") 359 << decorationName << "needs a single integer literal"; 360 } 361 decorations[words[0]].set( 362 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 363 break; 364 case spirv::Decoration::CacheControlLoadINTEL: { 365 LogicalResult res = deserializeCacheControlDecoration< 366 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>( 367 unknownLoc, opBuilder, decorations, words, symbol, decorationName, 368 "load"); 369 if (failed(res)) 370 return res; 371 break; 372 } 373 case spirv::Decoration::CacheControlStoreINTEL: { 374 LogicalResult res = deserializeCacheControlDecoration< 375 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>( 376 unknownLoc, opBuilder, decorations, words, symbol, decorationName, 377 "store"); 378 if (failed(res)) 379 return res; 380 break; 381 } 382 default: 383 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; 384 } 385 return success(); 386 } 387 388 LogicalResult 389 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) { 390 // The binary layout of OpMemberDecorate is different comparing to OpDecorate 391 if (words.size() < 3) { 392 return emitError(unknownLoc, 393 "OpMemberDecorate must have at least 3 operands"); 394 } 395 396 auto decoration = static_cast<spirv::Decoration>(words[2]); 397 if (decoration == spirv::Decoration::Offset && words.size() != 4) { 398 return emitError(unknownLoc, 399 " missing offset specification in OpMemberDecorate with " 400 "Offset decoration"); 401 } 402 ArrayRef<uint32_t> decorationOperands; 403 if (words.size() > 3) { 404 decorationOperands = words.slice(3); 405 } 406 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; 407 return success(); 408 } 409 410 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) { 411 if (words.size() < 3) { 412 return emitError(unknownLoc, "OpMemberName must have at least 3 operands"); 413 } 414 unsigned wordIndex = 2; 415 auto name = decodeStringLiteral(words, wordIndex); 416 if (wordIndex != words.size()) { 417 return emitError(unknownLoc, 418 "unexpected trailing words in OpMemberName instruction"); 419 } 420 memberNameMap[words[0]][words[1]] = name; 421 return success(); 422 } 423 424 LogicalResult spirv::Deserializer::setFunctionArgAttrs( 425 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) { 426 if (!decorations.contains(argID)) { 427 argAttrs[argIndex] = DictionaryAttr::get(context, {}); 428 return success(); 429 } 430 431 spirv::DecorationAttr foundDecorationAttr; 432 for (NamedAttribute decAttr : decorations[argID]) { 433 for (auto decoration : 434 {spirv::Decoration::Aliased, spirv::Decoration::Restrict, 435 spirv::Decoration::AliasedPointer, 436 spirv::Decoration::RestrictPointer}) { 437 438 if (decAttr.getName() != 439 getSymbolDecoration(stringifyDecoration(decoration))) 440 continue; 441 442 if (foundDecorationAttr) 443 return emitError(unknownLoc, 444 "more than one Aliased/Restrict decorations for " 445 "function argument with result <id> ") 446 << argID; 447 448 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration); 449 break; 450 } 451 } 452 453 if (!foundDecorationAttr) 454 return emitError(unknownLoc, "unimplemented decoration support for " 455 "function argument with result <id> ") 456 << argID; 457 458 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name), 459 foundDecorationAttr); 460 argAttrs[argIndex] = DictionaryAttr::get(context, attr); 461 return success(); 462 } 463 464 LogicalResult 465 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) { 466 if (curFunction) { 467 return emitError(unknownLoc, "found function inside function"); 468 } 469 470 // Get the result type 471 if (operands.size() != 4) { 472 return emitError(unknownLoc, "OpFunction must have 4 parameters"); 473 } 474 Type resultType = getType(operands[0]); 475 if (!resultType) { 476 return emitError(unknownLoc, "undefined result type from <id> ") 477 << operands[0]; 478 } 479 480 uint32_t fnID = operands[1]; 481 if (funcMap.count(fnID)) { 482 return emitError(unknownLoc, "duplicate function definition/declaration"); 483 } 484 485 auto fnControl = spirv::symbolizeFunctionControl(operands[2]); 486 if (!fnControl) { 487 return emitError(unknownLoc, "unknown Function Control: ") << operands[2]; 488 } 489 490 Type fnType = getType(operands[3]); 491 if (!fnType || !isa<FunctionType>(fnType)) { 492 return emitError(unknownLoc, "unknown function type from <id> ") 493 << operands[3]; 494 } 495 auto functionType = cast<FunctionType>(fnType); 496 497 if ((isVoidType(resultType) && functionType.getNumResults() != 0) || 498 (functionType.getNumResults() == 1 && 499 functionType.getResult(0) != resultType)) { 500 return emitError(unknownLoc, "mismatch in function type ") 501 << functionType << " and return type " << resultType << " specified"; 502 } 503 504 std::string fnName = getFunctionSymbol(fnID); 505 auto funcOp = opBuilder.create<spirv::FuncOp>( 506 unknownLoc, fnName, functionType, fnControl.value()); 507 // Processing other function attributes. 508 if (decorations.count(fnID)) { 509 for (auto attr : decorations[fnID].getAttrs()) { 510 funcOp->setAttr(attr.getName(), attr.getValue()); 511 } 512 } 513 curFunction = funcMap[fnID] = funcOp; 514 auto *entryBlock = funcOp.addEntryBlock(); 515 LLVM_DEBUG({ 516 logger.startLine() 517 << "//===-------------------------------------------===//\n"; 518 logger.startLine() << "[fn] name: " << fnName << "\n"; 519 logger.startLine() << "[fn] type: " << fnType << "\n"; 520 logger.startLine() << "[fn] ID: " << fnID << "\n"; 521 logger.startLine() << "[fn] entry block: " << entryBlock << "\n"; 522 logger.indent(); 523 }); 524 525 SmallVector<Attribute> argAttrs; 526 argAttrs.resize(functionType.getNumInputs()); 527 528 // Parse the op argument instructions 529 if (functionType.getNumInputs()) { 530 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { 531 auto argType = functionType.getInput(i); 532 spirv::Opcode opcode = spirv::Opcode::OpNop; 533 ArrayRef<uint32_t> operands; 534 if (failed(sliceInstruction(opcode, operands, 535 spirv::Opcode::OpFunctionParameter))) { 536 return failure(); 537 } 538 if (opcode != spirv::Opcode::OpFunctionParameter) { 539 return emitError( 540 unknownLoc, 541 "missing OpFunctionParameter instruction for argument ") 542 << i; 543 } 544 if (operands.size() != 2) { 545 return emitError( 546 unknownLoc, 547 "expected result type and result <id> for OpFunctionParameter"); 548 } 549 auto argDefinedType = getType(operands[0]); 550 if (!argDefinedType || argDefinedType != argType) { 551 return emitError(unknownLoc, 552 "mismatch in argument type between function type " 553 "definition ") 554 << functionType << " and argument type definition " 555 << argDefinedType << " at argument " << i; 556 } 557 if (getValue(operands[1])) { 558 return emitError(unknownLoc, "duplicate definition of result <id> ") 559 << operands[1]; 560 } 561 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) { 562 return failure(); 563 } 564 565 auto argValue = funcOp.getArgument(i); 566 valueMap[operands[1]] = argValue; 567 } 568 } 569 570 if (llvm::any_of(argAttrs, [](Attribute attr) { 571 auto argAttr = cast<DictionaryAttr>(attr); 572 return !argAttr.empty(); 573 })) 574 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs)); 575 576 // entryBlock is needed to access the arguments, Once that is done, we can 577 // erase the block for functions with 'Import' LinkageAttributes, since these 578 // are essentially function declarations, so they have no body. 579 auto linkageAttr = funcOp.getLinkageAttributes(); 580 auto hasImportLinkage = 581 linkageAttr && (linkageAttr.value().getLinkageType().getValue() == 582 spirv::LinkageType::Import); 583 if (hasImportLinkage) 584 funcOp.eraseBody(); 585 586 // RAII guard to reset the insertion point to the module's region after 587 // deserializing the body of this function. 588 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 589 590 spirv::Opcode opcode = spirv::Opcode::OpNop; 591 ArrayRef<uint32_t> instOperands; 592 593 // Special handling for the entry block. We need to make sure it starts with 594 // an OpLabel instruction. The entry block takes the same parameters as the 595 // function. All other blocks do not take any parameter. We have already 596 // created the entry block, here we need to register it to the correct label 597 // <id>. 598 if (failed(sliceInstruction(opcode, instOperands, 599 spirv::Opcode::OpFunctionEnd))) { 600 return failure(); 601 } 602 if (opcode == spirv::Opcode::OpFunctionEnd) { 603 return processFunctionEnd(instOperands); 604 } 605 if (opcode != spirv::Opcode::OpLabel) { 606 return emitError(unknownLoc, "a basic block must start with OpLabel"); 607 } 608 if (instOperands.size() != 1) { 609 return emitError(unknownLoc, "OpLabel should only have result <id>"); 610 } 611 blockMap[instOperands[0]] = entryBlock; 612 if (failed(processLabel(instOperands))) { 613 return failure(); 614 } 615 616 // Then process all the other instructions in the function until we hit 617 // OpFunctionEnd. 618 while (succeeded(sliceInstruction(opcode, instOperands, 619 spirv::Opcode::OpFunctionEnd)) && 620 opcode != spirv::Opcode::OpFunctionEnd) { 621 if (failed(processInstruction(opcode, instOperands))) { 622 return failure(); 623 } 624 } 625 if (opcode != spirv::Opcode::OpFunctionEnd) { 626 return failure(); 627 } 628 629 return processFunctionEnd(instOperands); 630 } 631 632 LogicalResult 633 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) { 634 // Process OpFunctionEnd. 635 if (!operands.empty()) { 636 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); 637 } 638 639 // Wire up block arguments from OpPhi instructions. 640 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop 641 // ops. 642 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) { 643 return failure(); 644 } 645 646 curBlock = nullptr; 647 curFunction = std::nullopt; 648 649 LLVM_DEBUG({ 650 logger.unindent(); 651 logger.startLine() 652 << "//===-------------------------------------------===//\n"; 653 }); 654 return success(); 655 } 656 657 std::optional<std::pair<Attribute, Type>> 658 spirv::Deserializer::getConstant(uint32_t id) { 659 auto constIt = constantMap.find(id); 660 if (constIt == constantMap.end()) 661 return std::nullopt; 662 return constIt->getSecond(); 663 } 664 665 std::optional<spirv::SpecConstOperationMaterializationInfo> 666 spirv::Deserializer::getSpecConstantOperation(uint32_t id) { 667 auto constIt = specConstOperationMap.find(id); 668 if (constIt == specConstOperationMap.end()) 669 return std::nullopt; 670 return constIt->getSecond(); 671 } 672 673 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { 674 auto funcName = nameMap.lookup(id).str(); 675 if (funcName.empty()) { 676 funcName = "spirv_fn_" + std::to_string(id); 677 } 678 return funcName; 679 } 680 681 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { 682 auto constName = nameMap.lookup(id).str(); 683 if (constName.empty()) { 684 constName = "spirv_spec_const_" + std::to_string(id); 685 } 686 return constName; 687 } 688 689 spirv::SpecConstantOp 690 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, 691 TypedAttr defaultValue) { 692 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 693 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, 694 defaultValue); 695 if (decorations.count(resultID)) { 696 for (auto attr : decorations[resultID].getAttrs()) 697 op->setAttr(attr.getName(), attr.getValue()); 698 } 699 specConstMap[resultID] = op; 700 return op; 701 } 702 703 LogicalResult 704 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { 705 unsigned wordIndex = 0; 706 if (operands.size() < 3) { 707 return emitError( 708 unknownLoc, 709 "OpVariable needs at least 3 operands, type, <id> and storage class"); 710 } 711 712 // Result Type. 713 auto type = getType(operands[wordIndex]); 714 if (!type) { 715 return emitError(unknownLoc, "unknown result type <id> : ") 716 << operands[wordIndex]; 717 } 718 auto ptrType = dyn_cast<spirv::PointerType>(type); 719 if (!ptrType) { 720 return emitError(unknownLoc, 721 "expected a result type <id> to be a spirv.ptr, found : ") 722 << type; 723 } 724 wordIndex++; 725 726 // Result <id>. 727 auto variableID = operands[wordIndex]; 728 auto variableName = nameMap.lookup(variableID).str(); 729 if (variableName.empty()) { 730 variableName = "spirv_var_" + std::to_string(variableID); 731 } 732 wordIndex++; 733 734 // Storage class. 735 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); 736 if (ptrType.getStorageClass() != storageClass) { 737 return emitError(unknownLoc, "mismatch in storage class of pointer type ") 738 << type << " and that specified in OpVariable instruction : " 739 << stringifyStorageClass(storageClass); 740 } 741 wordIndex++; 742 743 // Initializer. 744 FlatSymbolRefAttr initializer = nullptr; 745 746 if (wordIndex < operands.size()) { 747 Operation *op = nullptr; 748 749 if (auto initOp = getGlobalVariable(operands[wordIndex])) 750 op = initOp; 751 else if (auto initOp = getSpecConstant(operands[wordIndex])) 752 op = initOp; 753 else if (auto initOp = getSpecConstantComposite(operands[wordIndex])) 754 op = initOp; 755 else 756 return emitError(unknownLoc, "unknown <id> ") 757 << operands[wordIndex] << "used as initializer"; 758 759 initializer = SymbolRefAttr::get(op); 760 wordIndex++; 761 } 762 if (wordIndex != operands.size()) { 763 return emitError(unknownLoc, 764 "found more operands than expected when deserializing " 765 "OpVariable instruction, only ") 766 << wordIndex << " of " << operands.size() << " processed"; 767 } 768 auto loc = createFileLineColLoc(opBuilder); 769 auto varOp = opBuilder.create<spirv::GlobalVariableOp>( 770 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), 771 initializer); 772 773 // Decorations. 774 if (decorations.count(variableID)) { 775 for (auto attr : decorations[variableID].getAttrs()) 776 varOp->setAttr(attr.getName(), attr.getValue()); 777 } 778 globalVariableMap[variableID] = varOp; 779 return success(); 780 } 781 782 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) { 783 auto constInfo = getConstant(id); 784 if (!constInfo) { 785 return nullptr; 786 } 787 return dyn_cast<IntegerAttr>(constInfo->first); 788 } 789 790 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) { 791 if (operands.size() < 2) { 792 return emitError(unknownLoc, "OpName needs at least 2 operands"); 793 } 794 if (!nameMap.lookup(operands[0]).empty()) { 795 return emitError(unknownLoc, "duplicate name found for result <id> ") 796 << operands[0]; 797 } 798 unsigned wordIndex = 1; 799 StringRef name = decodeStringLiteral(operands, wordIndex); 800 if (wordIndex != operands.size()) { 801 return emitError(unknownLoc, 802 "unexpected trailing words in OpName instruction"); 803 } 804 nameMap[operands[0]] = name; 805 return success(); 806 } 807 808 //===----------------------------------------------------------------------===// 809 // Type 810 //===----------------------------------------------------------------------===// 811 812 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, 813 ArrayRef<uint32_t> operands) { 814 if (operands.empty()) { 815 return emitError(unknownLoc, "type instruction with opcode ") 816 << spirv::stringifyOpcode(opcode) << " needs at least one <id>"; 817 } 818 819 /// TODO: Types might be forward declared in some instructions and need to be 820 /// handled appropriately. 821 if (typeMap.count(operands[0])) { 822 return emitError(unknownLoc, "duplicate definition for result <id> ") 823 << operands[0]; 824 } 825 826 switch (opcode) { 827 case spirv::Opcode::OpTypeVoid: 828 if (operands.size() != 1) 829 return emitError(unknownLoc, "OpTypeVoid must have no parameters"); 830 typeMap[operands[0]] = opBuilder.getNoneType(); 831 break; 832 case spirv::Opcode::OpTypeBool: 833 if (operands.size() != 1) 834 return emitError(unknownLoc, "OpTypeBool must have no parameters"); 835 typeMap[operands[0]] = opBuilder.getI1Type(); 836 break; 837 case spirv::Opcode::OpTypeInt: { 838 if (operands.size() != 3) 839 return emitError( 840 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters"); 841 842 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 843 // to preserve or validate. 844 // 0 indicates unsigned, or no signedness semantics 845 // 1 indicates signed semantics." 846 // 847 // So we cannot differentiate signless and unsigned integers; always use 848 // signless semantics for such cases. 849 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed 850 : IntegerType::SignednessSemantics::Signless; 851 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign); 852 } break; 853 case spirv::Opcode::OpTypeFloat: { 854 if (operands.size() != 2) 855 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); 856 857 Type floatTy; 858 switch (operands[1]) { 859 case 16: 860 floatTy = opBuilder.getF16Type(); 861 break; 862 case 32: 863 floatTy = opBuilder.getF32Type(); 864 break; 865 case 64: 866 floatTy = opBuilder.getF64Type(); 867 break; 868 default: 869 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ") 870 << operands[1]; 871 } 872 typeMap[operands[0]] = floatTy; 873 } break; 874 case spirv::Opcode::OpTypeVector: { 875 if (operands.size() != 3) { 876 return emitError( 877 unknownLoc, 878 "OpTypeVector must have element type and count parameters"); 879 } 880 Type elementTy = getType(operands[1]); 881 if (!elementTy) { 882 return emitError(unknownLoc, "OpTypeVector references undefined <id> ") 883 << operands[1]; 884 } 885 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); 886 } break; 887 case spirv::Opcode::OpTypePointer: { 888 return processOpTypePointer(operands); 889 } break; 890 case spirv::Opcode::OpTypeArray: 891 return processArrayType(operands); 892 case spirv::Opcode::OpTypeCooperativeMatrixKHR: 893 return processCooperativeMatrixTypeKHR(operands); 894 case spirv::Opcode::OpTypeFunction: 895 return processFunctionType(operands); 896 case spirv::Opcode::OpTypeImage: 897 return processImageType(operands); 898 case spirv::Opcode::OpTypeSampledImage: 899 return processSampledImageType(operands); 900 case spirv::Opcode::OpTypeRuntimeArray: 901 return processRuntimeArrayType(operands); 902 case spirv::Opcode::OpTypeStruct: 903 return processStructType(operands); 904 case spirv::Opcode::OpTypeMatrix: 905 return processMatrixType(operands); 906 default: 907 return emitError(unknownLoc, "unhandled type instruction"); 908 } 909 return success(); 910 } 911 912 LogicalResult 913 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { 914 if (operands.size() != 3) 915 return emitError(unknownLoc, "OpTypePointer must have two parameters"); 916 917 auto pointeeType = getType(operands[2]); 918 if (!pointeeType) 919 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ") 920 << operands[2]; 921 922 uint32_t typePointerID = operands[0]; 923 auto storageClass = static_cast<spirv::StorageClass>(operands[1]); 924 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); 925 926 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos); 927 deferredStructIt != std::end(deferredStructTypesInfos);) { 928 for (auto *unresolvedMemberIt = 929 std::begin(deferredStructIt->unresolvedMemberTypes); 930 unresolvedMemberIt != 931 std::end(deferredStructIt->unresolvedMemberTypes);) { 932 if (unresolvedMemberIt->first == typePointerID) { 933 // The newly constructed pointer type can resolve one of the 934 // deferred struct type members; update the memberTypes list and 935 // clean the unresolvedMemberTypes list accordingly. 936 deferredStructIt->memberTypes[unresolvedMemberIt->second] = 937 typeMap[typePointerID]; 938 unresolvedMemberIt = 939 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt); 940 } else { 941 ++unresolvedMemberIt; 942 } 943 } 944 945 if (deferredStructIt->unresolvedMemberTypes.empty()) { 946 // All deferred struct type members are now resolved, set the struct body. 947 auto structType = deferredStructIt->deferredStructType; 948 949 assert(structType && "expected a spirv::StructType"); 950 assert(structType.isIdentified() && "expected an indentified struct"); 951 952 if (failed(structType.trySetBody( 953 deferredStructIt->memberTypes, deferredStructIt->offsetInfo, 954 deferredStructIt->memberDecorationsInfo))) 955 return failure(); 956 957 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); 958 } else { 959 ++deferredStructIt; 960 } 961 } 962 963 return success(); 964 } 965 966 LogicalResult 967 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) { 968 if (operands.size() != 3) { 969 return emitError(unknownLoc, 970 "OpTypeArray must have element type and count parameters"); 971 } 972 973 Type elementTy = getType(operands[1]); 974 if (!elementTy) { 975 return emitError(unknownLoc, "OpTypeArray references undefined <id> ") 976 << operands[1]; 977 } 978 979 unsigned count = 0; 980 // TODO: The count can also come frome a specialization constant. 981 auto countInfo = getConstant(operands[2]); 982 if (!countInfo) { 983 return emitError(unknownLoc, "OpTypeArray count <id> ") 984 << operands[2] << "can only come from normal constant right now"; 985 } 986 987 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) { 988 count = intVal.getValue().getZExtValue(); 989 } else { 990 return emitError(unknownLoc, "OpTypeArray count must come from a " 991 "scalar integer constant instruction"); 992 } 993 994 typeMap[operands[0]] = spirv::ArrayType::get( 995 elementTy, count, typeDecorations.lookup(operands[0])); 996 return success(); 997 } 998 999 LogicalResult 1000 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { 1001 assert(!operands.empty() && "No operands for processing function type"); 1002 if (operands.size() == 1) { 1003 return emitError(unknownLoc, "missing return type for OpTypeFunction"); 1004 } 1005 auto returnType = getType(operands[1]); 1006 if (!returnType) { 1007 return emitError(unknownLoc, "unknown return type in OpTypeFunction"); 1008 } 1009 SmallVector<Type, 1> argTypes; 1010 for (size_t i = 2, e = operands.size(); i < e; ++i) { 1011 auto ty = getType(operands[i]); 1012 if (!ty) { 1013 return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); 1014 } 1015 argTypes.push_back(ty); 1016 } 1017 ArrayRef<Type> returnTypes; 1018 if (!isVoidType(returnType)) { 1019 returnTypes = llvm::ArrayRef(returnType); 1020 } 1021 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes); 1022 return success(); 1023 } 1024 1025 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR( 1026 ArrayRef<uint32_t> operands) { 1027 if (operands.size() != 6) { 1028 return emitError(unknownLoc, 1029 "OpTypeCooperativeMatrixKHR must have element type, " 1030 "scope, row and column parameters, and use"); 1031 } 1032 1033 Type elementTy = getType(operands[1]); 1034 if (!elementTy) { 1035 return emitError(unknownLoc, 1036 "OpTypeCooperativeMatrixKHR references undefined <id> ") 1037 << operands[1]; 1038 } 1039 1040 std::optional<spirv::Scope> scope = 1041 spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); 1042 if (!scope) { 1043 return emitError( 1044 unknownLoc, 1045 "OpTypeCooperativeMatrixKHR references undefined scope <id> ") 1046 << operands[2]; 1047 } 1048 1049 unsigned rows = getConstantInt(operands[3]).getInt(); 1050 unsigned columns = getConstantInt(operands[4]).getInt(); 1051 1052 std::optional<spirv::CooperativeMatrixUseKHR> use = 1053 spirv::symbolizeCooperativeMatrixUseKHR( 1054 getConstantInt(operands[5]).getInt()); 1055 if (!use) { 1056 return emitError( 1057 unknownLoc, 1058 "OpTypeCooperativeMatrixKHR references undefined use <id> ") 1059 << operands[5]; 1060 } 1061 1062 typeMap[operands[0]] = 1063 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use); 1064 return success(); 1065 } 1066 1067 LogicalResult 1068 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) { 1069 if (operands.size() != 2) { 1070 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands"); 1071 } 1072 Type memberType = getType(operands[1]); 1073 if (!memberType) { 1074 return emitError(unknownLoc, 1075 "OpTypeRuntimeArray references undefined <id> ") 1076 << operands[1]; 1077 } 1078 typeMap[operands[0]] = spirv::RuntimeArrayType::get( 1079 memberType, typeDecorations.lookup(operands[0])); 1080 return success(); 1081 } 1082 1083 LogicalResult 1084 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { 1085 // TODO: Find a way to handle identified structs when debug info is stripped. 1086 1087 if (operands.empty()) { 1088 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>"); 1089 } 1090 1091 if (operands.size() == 1) { 1092 // Handle empty struct. 1093 typeMap[operands[0]] = 1094 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str()); 1095 return success(); 1096 } 1097 1098 // First element is operand ID, second element is member index in the struct. 1099 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes; 1100 SmallVector<Type, 4> memberTypes; 1101 1102 for (auto op : llvm::drop_begin(operands, 1)) { 1103 Type memberType = getType(op); 1104 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0); 1105 1106 if (!memberType && !typeForwardPtr) 1107 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ") 1108 << op; 1109 1110 if (!memberType) 1111 unresolvedMemberTypes.emplace_back(op, memberTypes.size()); 1112 1113 memberTypes.push_back(memberType); 1114 } 1115 1116 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; 1117 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; 1118 if (memberDecorationMap.count(operands[0])) { 1119 auto &allMemberDecorations = memberDecorationMap[operands[0]]; 1120 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) { 1121 if (allMemberDecorations.count(memberIndex)) { 1122 for (auto &memberDecoration : allMemberDecorations[memberIndex]) { 1123 // Check for offset. 1124 if (memberDecoration.first == spirv::Decoration::Offset) { 1125 // If offset info is empty, resize to the number of members; 1126 if (offsetInfo.empty()) { 1127 offsetInfo.resize(memberTypes.size()); 1128 } 1129 offsetInfo[memberIndex] = memberDecoration.second[0]; 1130 } else { 1131 if (!memberDecoration.second.empty()) { 1132 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, 1133 memberDecoration.first, 1134 memberDecoration.second[0]); 1135 } else { 1136 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, 1137 memberDecoration.first, 0); 1138 } 1139 } 1140 } 1141 } 1142 } 1143 } 1144 1145 uint32_t structID = operands[0]; 1146 std::string structIdentifier = nameMap.lookup(structID).str(); 1147 1148 if (structIdentifier.empty()) { 1149 assert(unresolvedMemberTypes.empty() && 1150 "didn't expect unresolved member types"); 1151 typeMap[structID] = 1152 spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); 1153 } else { 1154 auto structTy = spirv::StructType::getIdentified(context, structIdentifier); 1155 typeMap[structID] = structTy; 1156 1157 if (!unresolvedMemberTypes.empty()) 1158 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, 1159 memberTypes, offsetInfo, 1160 memberDecorationsInfo}); 1161 else if (failed(structTy.trySetBody(memberTypes, offsetInfo, 1162 memberDecorationsInfo))) 1163 return failure(); 1164 } 1165 1166 // TODO: Update StructType to have member name as attribute as 1167 // well. 1168 return success(); 1169 } 1170 1171 LogicalResult 1172 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) { 1173 if (operands.size() != 3) { 1174 // Three operands are needed: result_id, column_type, and column_count 1175 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands" 1176 " (result_id, column_type, and column_count)"); 1177 } 1178 // Matrix columns must be of vector type 1179 Type elementTy = getType(operands[1]); 1180 if (!elementTy) { 1181 return emitError(unknownLoc, 1182 "OpTypeMatrix references undefined column type.") 1183 << operands[1]; 1184 } 1185 1186 uint32_t colsCount = operands[2]; 1187 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount); 1188 return success(); 1189 } 1190 1191 LogicalResult 1192 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) { 1193 if (operands.size() != 2) 1194 return emitError(unknownLoc, 1195 "OpTypeForwardPointer instruction must have two operands"); 1196 1197 typeForwardPointerIDs.insert(operands[0]); 1198 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer 1199 // instruction that defines the actual type. 1200 1201 return success(); 1202 } 1203 1204 LogicalResult 1205 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) { 1206 // TODO: Add support for Access Qualifier. 1207 if (operands.size() != 8) 1208 return emitError( 1209 unknownLoc, 1210 "OpTypeImage with non-eight operands are not supported yet"); 1211 1212 Type elementTy = getType(operands[1]); 1213 if (!elementTy) 1214 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ") 1215 << operands[1]; 1216 1217 auto dim = spirv::symbolizeDim(operands[2]); 1218 if (!dim) 1219 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ") 1220 << operands[2]; 1221 1222 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]); 1223 if (!depthInfo) 1224 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ") 1225 << operands[3]; 1226 1227 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]); 1228 if (!arrayedInfo) 1229 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ") 1230 << operands[4]; 1231 1232 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]); 1233 if (!samplingInfo) 1234 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5]; 1235 1236 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]); 1237 if (!samplerUseInfo) 1238 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ") 1239 << operands[6]; 1240 1241 auto format = spirv::symbolizeImageFormat(operands[7]); 1242 if (!format) 1243 return emitError(unknownLoc, "unknown Format for OpTypeImage: ") 1244 << operands[7]; 1245 1246 typeMap[operands[0]] = spirv::ImageType::get( 1247 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(), 1248 samplingInfo.value(), samplerUseInfo.value(), format.value()); 1249 return success(); 1250 } 1251 1252 LogicalResult 1253 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) { 1254 if (operands.size() != 2) 1255 return emitError(unknownLoc, "OpTypeSampledImage must have two operands"); 1256 1257 Type elementTy = getType(operands[1]); 1258 if (!elementTy) 1259 return emitError(unknownLoc, 1260 "OpTypeSampledImage references undefined <id>: ") 1261 << operands[1]; 1262 1263 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy); 1264 return success(); 1265 } 1266 1267 //===----------------------------------------------------------------------===// 1268 // Constant 1269 //===----------------------------------------------------------------------===// 1270 1271 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands, 1272 bool isSpec) { 1273 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant"; 1274 1275 if (operands.size() < 2) { 1276 return emitError(unknownLoc) 1277 << opname << " must have type <id> and result <id>"; 1278 } 1279 if (operands.size() < 3) { 1280 return emitError(unknownLoc) 1281 << opname << " must have at least 1 more parameter"; 1282 } 1283 1284 Type resultType = getType(operands[0]); 1285 if (!resultType) { 1286 return emitError(unknownLoc, "undefined result type from <id> ") 1287 << operands[0]; 1288 } 1289 1290 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { 1291 if (bitwidth == 64) { 1292 if (operands.size() == 4) { 1293 return success(); 1294 } 1295 return emitError(unknownLoc) 1296 << opname << " should have 2 parameters for 64-bit values"; 1297 } 1298 if (bitwidth <= 32) { 1299 if (operands.size() == 3) { 1300 return success(); 1301 } 1302 1303 return emitError(unknownLoc) 1304 << opname 1305 << " should have 1 parameter for values with no more than 32 bits"; 1306 } 1307 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ") 1308 << bitwidth; 1309 }; 1310 1311 auto resultID = operands[1]; 1312 1313 if (auto intType = dyn_cast<IntegerType>(resultType)) { 1314 auto bitwidth = intType.getWidth(); 1315 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1316 return failure(); 1317 } 1318 1319 APInt value; 1320 if (bitwidth == 64) { 1321 // 64-bit integers are represented with two SPIR-V words. According to 1322 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1323 // literal’s low-order words appear first." 1324 struct DoubleWord { 1325 uint32_t word1; 1326 uint32_t word2; 1327 } words = {operands[2], operands[3]}; 1328 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true); 1329 } else if (bitwidth <= 32) { 1330 value = APInt(bitwidth, operands[2], /*isSigned=*/true, 1331 /*implicitTrunc=*/true); 1332 } 1333 1334 auto attr = opBuilder.getIntegerAttr(intType, value); 1335 1336 if (isSpec) { 1337 createSpecConstant(unknownLoc, resultID, attr); 1338 } else { 1339 // For normal constants, we just record the attribute (and its type) for 1340 // later materialization at use sites. 1341 constantMap.try_emplace(resultID, attr, intType); 1342 } 1343 1344 return success(); 1345 } 1346 1347 if (auto floatType = dyn_cast<FloatType>(resultType)) { 1348 auto bitwidth = floatType.getWidth(); 1349 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1350 return failure(); 1351 } 1352 1353 APFloat value(0.f); 1354 if (floatType.isF64()) { 1355 // Double values are represented with two SPIR-V words. According to 1356 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1357 // literal’s low-order words appear first." 1358 struct DoubleWord { 1359 uint32_t word1; 1360 uint32_t word2; 1361 } words = {operands[2], operands[3]}; 1362 value = APFloat(llvm::bit_cast<double>(words)); 1363 } else if (floatType.isF32()) { 1364 value = APFloat(llvm::bit_cast<float>(operands[2])); 1365 } else if (floatType.isF16()) { 1366 APInt data(16, operands[2]); 1367 value = APFloat(APFloat::IEEEhalf(), data); 1368 } 1369 1370 auto attr = opBuilder.getFloatAttr(floatType, value); 1371 if (isSpec) { 1372 createSpecConstant(unknownLoc, resultID, attr); 1373 } else { 1374 // For normal constants, we just record the attribute (and its type) for 1375 // later materialization at use sites. 1376 constantMap.try_emplace(resultID, attr, floatType); 1377 } 1378 1379 return success(); 1380 } 1381 1382 return emitError(unknownLoc, "OpConstant can only generate values of " 1383 "scalar integer or floating-point type"); 1384 } 1385 1386 LogicalResult spirv::Deserializer::processConstantBool( 1387 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) { 1388 if (operands.size() != 2) { 1389 return emitError(unknownLoc, "Op") 1390 << (isSpec ? "Spec" : "") << "Constant" 1391 << (isTrue ? "True" : "False") 1392 << " must have type <id> and result <id>"; 1393 } 1394 1395 auto attr = opBuilder.getBoolAttr(isTrue); 1396 auto resultID = operands[1]; 1397 if (isSpec) { 1398 createSpecConstant(unknownLoc, resultID, attr); 1399 } else { 1400 // For normal constants, we just record the attribute (and its type) for 1401 // later materialization at use sites. 1402 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type()); 1403 } 1404 1405 return success(); 1406 } 1407 1408 LogicalResult 1409 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) { 1410 if (operands.size() < 2) { 1411 return emitError(unknownLoc, 1412 "OpConstantComposite must have type <id> and result <id>"); 1413 } 1414 if (operands.size() < 3) { 1415 return emitError(unknownLoc, 1416 "OpConstantComposite must have at least 1 parameter"); 1417 } 1418 1419 Type resultType = getType(operands[0]); 1420 if (!resultType) { 1421 return emitError(unknownLoc, "undefined result type from <id> ") 1422 << operands[0]; 1423 } 1424 1425 SmallVector<Attribute, 4> elements; 1426 elements.reserve(operands.size() - 2); 1427 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1428 auto elementInfo = getConstant(operands[i]); 1429 if (!elementInfo) { 1430 return emitError(unknownLoc, "OpConstantComposite component <id> ") 1431 << operands[i] << " must come from a normal constant"; 1432 } 1433 elements.push_back(elementInfo->first); 1434 } 1435 1436 auto resultID = operands[1]; 1437 if (auto vectorType = dyn_cast<VectorType>(resultType)) { 1438 auto attr = DenseElementsAttr::get(vectorType, elements); 1439 // For normal constants, we just record the attribute (and its type) for 1440 // later materialization at use sites. 1441 constantMap.try_emplace(resultID, attr, resultType); 1442 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) { 1443 auto attr = opBuilder.getArrayAttr(elements); 1444 constantMap.try_emplace(resultID, attr, resultType); 1445 } else { 1446 return emitError(unknownLoc, "unsupported OpConstantComposite type: ") 1447 << resultType; 1448 } 1449 1450 return success(); 1451 } 1452 1453 LogicalResult 1454 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) { 1455 if (operands.size() < 2) { 1456 return emitError(unknownLoc, 1457 "OpConstantComposite must have type <id> and result <id>"); 1458 } 1459 if (operands.size() < 3) { 1460 return emitError(unknownLoc, 1461 "OpConstantComposite must have at least 1 parameter"); 1462 } 1463 1464 Type resultType = getType(operands[0]); 1465 if (!resultType) { 1466 return emitError(unknownLoc, "undefined result type from <id> ") 1467 << operands[0]; 1468 } 1469 1470 auto resultID = operands[1]; 1471 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 1472 1473 SmallVector<Attribute, 4> elements; 1474 elements.reserve(operands.size() - 2); 1475 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1476 auto elementInfo = getSpecConstant(operands[i]); 1477 elements.push_back(SymbolRefAttr::get(elementInfo)); 1478 } 1479 1480 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>( 1481 unknownLoc, TypeAttr::get(resultType), symName, 1482 opBuilder.getArrayAttr(elements)); 1483 specConstCompositeMap[resultID] = op; 1484 1485 return success(); 1486 } 1487 1488 LogicalResult 1489 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) { 1490 if (operands.size() < 3) 1491 return emitError(unknownLoc, "OpConstantOperation must have type <id>, " 1492 "result <id>, and operand opcode"); 1493 1494 uint32_t resultTypeID = operands[0]; 1495 1496 if (!getType(resultTypeID)) 1497 return emitError(unknownLoc, "undefined result type from <id> ") 1498 << resultTypeID; 1499 1500 uint32_t resultID = operands[1]; 1501 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]); 1502 auto emplaceResult = specConstOperationMap.try_emplace( 1503 resultID, 1504 SpecConstOperationMaterializationInfo{ 1505 enclosedOpcode, resultTypeID, 1506 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}}); 1507 1508 if (!emplaceResult.second) 1509 return emitError(unknownLoc, "value with <id>: ") 1510 << resultID << " is probably defined before."; 1511 1512 return success(); 1513 } 1514 1515 Value spirv::Deserializer::materializeSpecConstantOperation( 1516 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, 1517 ArrayRef<uint32_t> enclosedOpOperands) { 1518 1519 Type resultType = getType(resultTypeID); 1520 1521 // Instructions wrapped by OpSpecConstantOp need an ID for their 1522 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V 1523 // dialect wrapped op. For that purpose, a new value map is created and "fake" 1524 // ID in that map is assigned to the result of the enclosed instruction. Note 1525 // that there is no need to update this fake ID since we only need to 1526 // reference the created Value for the enclosed op from the spv::YieldOp 1527 // created later in this method (both of which are the only values in their 1528 // region: the SpecConstantOperation's region). If we encounter another 1529 // SpecConstantOperation in the module, we simply re-use the fake ID since the 1530 // previous Value assigned to it isn't visible in the current scope anyway. 1531 DenseMap<uint32_t, Value> newValueMap; 1532 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap); 1533 constexpr uint32_t fakeID = static_cast<uint32_t>(-3); 1534 1535 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands; 1536 enclosedOpResultTypeAndOperands.push_back(resultTypeID); 1537 enclosedOpResultTypeAndOperands.push_back(fakeID); 1538 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(), 1539 enclosedOpOperands.end()); 1540 1541 // Process enclosed instruction before creating the enclosing 1542 // specConstantOperation (and its region). This way, references to constants, 1543 // global variables, and spec constants will be materialized outside the new 1544 // op's region. For more info, see Deserializer::getValue's implementation. 1545 if (failed( 1546 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands))) 1547 return Value(); 1548 1549 // Since the enclosed op is emitted in the current block, split it in a 1550 // separate new block. 1551 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back()); 1552 1553 auto loc = createFileLineColLoc(opBuilder); 1554 auto specConstOperationOp = 1555 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType); 1556 1557 Region &body = specConstOperationOp.getBody(); 1558 // Move the new block into SpecConstantOperation's body. 1559 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(), 1560 Region::iterator(enclosedBlock)); 1561 Block &block = body.back(); 1562 1563 // RAII guard to reset the insertion point to the module's region after 1564 // deserializing the body of the specConstantOperation. 1565 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 1566 opBuilder.setInsertionPointToEnd(&block); 1567 1568 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0)); 1569 return specConstOperationOp.getResult(); 1570 } 1571 1572 LogicalResult 1573 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { 1574 if (operands.size() != 2) { 1575 return emitError(unknownLoc, 1576 "OpConstantNull must have type <id> and result <id>"); 1577 } 1578 1579 Type resultType = getType(operands[0]); 1580 if (!resultType) { 1581 return emitError(unknownLoc, "undefined result type from <id> ") 1582 << operands[0]; 1583 } 1584 1585 auto resultID = operands[1]; 1586 if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) { 1587 auto attr = opBuilder.getZeroAttr(resultType); 1588 // For normal constants, we just record the attribute (and its type) for 1589 // later materialization at use sites. 1590 constantMap.try_emplace(resultID, attr, resultType); 1591 return success(); 1592 } 1593 1594 return emitError(unknownLoc, "unsupported OpConstantNull type: ") 1595 << resultType; 1596 } 1597 1598 //===----------------------------------------------------------------------===// 1599 // Control flow 1600 //===----------------------------------------------------------------------===// 1601 1602 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) { 1603 if (auto *block = getBlock(id)) { 1604 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id 1605 << " @ " << block << "\n"); 1606 return block; 1607 } 1608 1609 // We don't know where this block will be placed finally (in a 1610 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the 1611 // function for now and sort out the proper place later. 1612 auto *block = curFunction->addBlock(); 1613 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id 1614 << " @ " << block << "\n"); 1615 return blockMap[id] = block; 1616 } 1617 1618 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) { 1619 if (!curBlock) { 1620 return emitError(unknownLoc, "OpBranch must appear inside a block"); 1621 } 1622 1623 if (operands.size() != 1) { 1624 return emitError(unknownLoc, "OpBranch must take exactly one target label"); 1625 } 1626 1627 auto *target = getOrCreateBlock(operands[0]); 1628 auto loc = createFileLineColLoc(opBuilder); 1629 // The preceding instruction for the OpBranch instruction could be an 1630 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have 1631 // the same OpLine information. 1632 opBuilder.create<spirv::BranchOp>(loc, target); 1633 1634 clearDebugLine(); 1635 return success(); 1636 } 1637 1638 LogicalResult 1639 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) { 1640 if (!curBlock) { 1641 return emitError(unknownLoc, 1642 "OpBranchConditional must appear inside a block"); 1643 } 1644 1645 if (operands.size() != 3 && operands.size() != 5) { 1646 return emitError(unknownLoc, 1647 "OpBranchConditional must have condition, true label, " 1648 "false label, and optionally two branch weights"); 1649 } 1650 1651 auto condition = getValue(operands[0]); 1652 auto *trueBlock = getOrCreateBlock(operands[1]); 1653 auto *falseBlock = getOrCreateBlock(operands[2]); 1654 1655 std::optional<std::pair<uint32_t, uint32_t>> weights; 1656 if (operands.size() == 5) { 1657 weights = std::make_pair(operands[3], operands[4]); 1658 } 1659 // The preceding instruction for the OpBranchConditional instruction could be 1660 // an OpSelectionMerge instruction, in this case they will have the same 1661 // OpLine information. 1662 auto loc = createFileLineColLoc(opBuilder); 1663 opBuilder.create<spirv::BranchConditionalOp>( 1664 loc, condition, trueBlock, 1665 /*trueArguments=*/ArrayRef<Value>(), falseBlock, 1666 /*falseArguments=*/ArrayRef<Value>(), weights); 1667 1668 clearDebugLine(); 1669 return success(); 1670 } 1671 1672 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) { 1673 if (!curFunction) { 1674 return emitError(unknownLoc, "OpLabel must appear inside a function"); 1675 } 1676 1677 if (operands.size() != 1) { 1678 return emitError(unknownLoc, "OpLabel should only have result <id>"); 1679 } 1680 1681 auto labelID = operands[0]; 1682 // We may have forward declared this block. 1683 auto *block = getOrCreateBlock(labelID); 1684 LLVM_DEBUG(logger.startLine() 1685 << "[block] populating block " << block << "\n"); 1686 // If we have seen this block, make sure it was just a forward declaration. 1687 assert(block->empty() && "re-deserialize the same block!"); 1688 1689 opBuilder.setInsertionPointToStart(block); 1690 blockMap[labelID] = curBlock = block; 1691 1692 return success(); 1693 } 1694 1695 LogicalResult 1696 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) { 1697 if (!curBlock) { 1698 return emitError(unknownLoc, "OpSelectionMerge must appear in a block"); 1699 } 1700 1701 if (operands.size() < 2) { 1702 return emitError( 1703 unknownLoc, 1704 "OpSelectionMerge must specify merge target and selection control"); 1705 } 1706 1707 auto *mergeBlock = getOrCreateBlock(operands[0]); 1708 auto loc = createFileLineColLoc(opBuilder); 1709 auto selectionControl = operands[1]; 1710 1711 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock) 1712 .second) { 1713 return emitError( 1714 unknownLoc, 1715 "a block cannot have more than one OpSelectionMerge instruction"); 1716 } 1717 1718 return success(); 1719 } 1720 1721 LogicalResult 1722 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) { 1723 if (!curBlock) { 1724 return emitError(unknownLoc, "OpLoopMerge must appear in a block"); 1725 } 1726 1727 if (operands.size() < 3) { 1728 return emitError(unknownLoc, "OpLoopMerge must specify merge target, " 1729 "continue target and loop control"); 1730 } 1731 1732 auto *mergeBlock = getOrCreateBlock(operands[0]); 1733 auto *continueBlock = getOrCreateBlock(operands[1]); 1734 auto loc = createFileLineColLoc(opBuilder); 1735 uint32_t loopControl = operands[2]; 1736 1737 if (!blockMergeInfo 1738 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock) 1739 .second) { 1740 return emitError( 1741 unknownLoc, 1742 "a block cannot have more than one OpLoopMerge instruction"); 1743 } 1744 1745 return success(); 1746 } 1747 1748 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) { 1749 if (!curBlock) { 1750 return emitError(unknownLoc, "OpPhi must appear in a block"); 1751 } 1752 1753 if (operands.size() < 4) { 1754 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, " 1755 "and variable-parent pairs"); 1756 } 1757 1758 // Create a block argument for this OpPhi instruction. 1759 Type blockArgType = getType(operands[0]); 1760 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc); 1761 valueMap[operands[1]] = blockArg; 1762 LLVM_DEBUG(logger.startLine() 1763 << "[phi] created block argument " << blockArg 1764 << " id = " << operands[1] << " of type " << blockArgType << "\n"); 1765 1766 // For each (value, predecessor) pair, insert the value to the predecessor's 1767 // blockPhiInfo entry so later we can fix the block argument there. 1768 for (unsigned i = 2, e = operands.size(); i < e; i += 2) { 1769 uint32_t value = operands[i]; 1770 Block *predecessor = getOrCreateBlock(operands[i + 1]); 1771 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock}; 1772 blockPhiInfo[predecessorTargetPair].push_back(value); 1773 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor 1774 << " with arg id = " << value << "\n"); 1775 } 1776 1777 return success(); 1778 } 1779 1780 namespace { 1781 /// A class for putting all blocks in a structured selection/loop in a 1782 /// spirv.mlir.selection/spirv.mlir.loop op. 1783 class ControlFlowStructurizer { 1784 public: 1785 #ifndef NDEBUG 1786 ControlFlowStructurizer(Location loc, uint32_t control, 1787 spirv::BlockMergeInfoMap &mergeInfo, Block *header, 1788 Block *merge, Block *cont, 1789 llvm::ScopedPrinter &logger) 1790 : location(loc), control(control), blockMergeInfo(mergeInfo), 1791 headerBlock(header), mergeBlock(merge), continueBlock(cont), 1792 logger(logger) {} 1793 #else 1794 ControlFlowStructurizer(Location loc, uint32_t control, 1795 spirv::BlockMergeInfoMap &mergeInfo, Block *header, 1796 Block *merge, Block *cont) 1797 : location(loc), control(control), blockMergeInfo(mergeInfo), 1798 headerBlock(header), mergeBlock(merge), continueBlock(cont) {} 1799 #endif 1800 1801 /// Structurizes the loop at the given `headerBlock`. 1802 /// 1803 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move 1804 /// all blocks in the structured loop into the spirv.mlir.loop's region. All 1805 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This 1806 /// method will also update `mergeInfo` by remapping all blocks inside to the 1807 /// newly cloned ones inside structured control flow op's regions. 1808 LogicalResult structurize(); 1809 1810 private: 1811 /// Creates a new spirv.mlir.selection op at the beginning of the 1812 /// `mergeBlock`. 1813 spirv::SelectionOp createSelectionOp(uint32_t selectionControl); 1814 1815 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`. 1816 spirv::LoopOp createLoopOp(uint32_t loopControl); 1817 1818 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`. 1819 void collectBlocksInConstruct(); 1820 1821 Location location; 1822 uint32_t control; 1823 1824 spirv::BlockMergeInfoMap &blockMergeInfo; 1825 1826 Block *headerBlock; 1827 Block *mergeBlock; 1828 Block *continueBlock; // nullptr for spirv.mlir.selection 1829 1830 SetVector<Block *> constructBlocks; 1831 1832 #ifndef NDEBUG 1833 /// A logger used to emit information during the deserialzation process. 1834 llvm::ScopedPrinter &logger; 1835 #endif 1836 }; 1837 } // namespace 1838 1839 spirv::SelectionOp 1840 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { 1841 // Create a builder and set the insertion point to the beginning of the 1842 // merge block so that the newly created SelectionOp will be inserted there. 1843 OpBuilder builder(&mergeBlock->front()); 1844 1845 auto control = static_cast<spirv::SelectionControl>(selectionControl); 1846 auto selectionOp = builder.create<spirv::SelectionOp>(location, control); 1847 selectionOp.addMergeBlock(builder); 1848 1849 return selectionOp; 1850 } 1851 1852 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { 1853 // Create a builder and set the insertion point to the beginning of the 1854 // merge block so that the newly created LoopOp will be inserted there. 1855 OpBuilder builder(&mergeBlock->front()); 1856 1857 auto control = static_cast<spirv::LoopControl>(loopControl); 1858 auto loopOp = builder.create<spirv::LoopOp>(location, control); 1859 loopOp.addEntryAndMergeBlock(builder); 1860 1861 return loopOp; 1862 } 1863 1864 void ControlFlowStructurizer::collectBlocksInConstruct() { 1865 assert(constructBlocks.empty() && "expected empty constructBlocks"); 1866 1867 // Put the header block in the work list first. 1868 constructBlocks.insert(headerBlock); 1869 1870 // For each item in the work list, add its successors excluding the merge 1871 // block. 1872 for (unsigned i = 0; i < constructBlocks.size(); ++i) { 1873 for (auto *successor : constructBlocks[i]->getSuccessors()) 1874 if (successor != mergeBlock) 1875 constructBlocks.insert(successor); 1876 } 1877 } 1878 1879 LogicalResult ControlFlowStructurizer::structurize() { 1880 Operation *op = nullptr; 1881 bool isLoop = continueBlock != nullptr; 1882 if (isLoop) { 1883 if (auto loopOp = createLoopOp(control)) 1884 op = loopOp.getOperation(); 1885 } else { 1886 if (auto selectionOp = createSelectionOp(control)) 1887 op = selectionOp.getOperation(); 1888 } 1889 if (!op) 1890 return failure(); 1891 Region &body = op->getRegion(0); 1892 1893 IRMapping mapper; 1894 // All references to the old merge block should be directed to the 1895 // selection/loop merge block in the SelectionOp/LoopOp's region. 1896 mapper.map(mergeBlock, &body.back()); 1897 1898 collectBlocksInConstruct(); 1899 1900 // We've identified all blocks belonging to the selection/loop's region. Now 1901 // need to "move" them into the selection/loop. Instead of really moving the 1902 // blocks, in the following we copy them and remap all values and branches. 1903 // This is because: 1904 // * Inserting a block into a region requires the block not in any region 1905 // before. But selections/loops can nest so we can create selection/loop ops 1906 // in a nested manner, which means some blocks may already be in a 1907 // selection/loop region when to be moved again. 1908 // * It's much trickier to fix up the branches into and out of the loop's 1909 // region: we need to treat not-moved blocks and moved blocks differently: 1910 // Not-moved blocks jumping to the loop header block need to jump to the 1911 // merge point containing the new loop op but not the loop continue block's 1912 // back edge. Moved blocks jumping out of the loop need to jump to the 1913 // merge block inside the loop region but not other not-moved blocks. 1914 // We cannot use replaceAllUsesWith clearly and it's harder to follow the 1915 // logic. 1916 1917 // Create a corresponding block in the SelectionOp/LoopOp's region for each 1918 // block in this loop construct. 1919 OpBuilder builder(body); 1920 for (auto *block : constructBlocks) { 1921 // Create a block and insert it before the selection/loop merge block in the 1922 // SelectionOp/LoopOp's region. 1923 auto *newBlock = builder.createBlock(&body.back()); 1924 mapper.map(block, newBlock); 1925 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock 1926 << " from block " << block << "\n"); 1927 if (!isFnEntryBlock(block)) { 1928 for (BlockArgument blockArg : block->getArguments()) { 1929 auto newArg = 1930 newBlock->addArgument(blockArg.getType(), blockArg.getLoc()); 1931 mapper.map(blockArg, newArg); 1932 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument " 1933 << blockArg << " to " << newArg << "\n"); 1934 } 1935 } else { 1936 LLVM_DEBUG(logger.startLine() 1937 << "[cf] block " << block << " is a function entry block\n"); 1938 } 1939 1940 for (auto &op : *block) 1941 newBlock->push_back(op.clone(mapper)); 1942 } 1943 1944 // Go through all ops and remap the operands. 1945 auto remapOperands = [&](Operation *op) { 1946 for (auto &operand : op->getOpOperands()) 1947 if (Value mappedOp = mapper.lookupOrNull(operand.get())) 1948 operand.set(mappedOp); 1949 for (auto &succOp : op->getBlockOperands()) 1950 if (Block *mappedOp = mapper.lookupOrNull(succOp.get())) 1951 succOp.set(mappedOp); 1952 }; 1953 for (auto &block : body) 1954 block.walk(remapOperands); 1955 1956 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to 1957 // the selection/loop construct into its region. Next we need to fix the 1958 // connections between this new SelectionOp/LoopOp with existing blocks. 1959 1960 // All existing incoming branches should go to the merge block, where the 1961 // SelectionOp/LoopOp resides right now. 1962 headerBlock->replaceAllUsesWith(mergeBlock); 1963 1964 LLVM_DEBUG({ 1965 logger.startLine() << "[cf] after cloning and fixing references:\n"; 1966 headerBlock->getParentOp()->print(logger.getOStream()); 1967 logger.startLine() << "\n"; 1968 }); 1969 1970 if (isLoop) { 1971 if (!mergeBlock->args_empty()) { 1972 return mergeBlock->getParentOp()->emitError( 1973 "OpPhi in loop merge block unsupported"); 1974 } 1975 1976 // The loop header block may have block arguments. Since now we place the 1977 // loop op inside the old merge block, we need to make sure the old merge 1978 // block has the same block argument list. 1979 for (BlockArgument blockArg : headerBlock->getArguments()) 1980 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc()); 1981 1982 // If the loop header block has block arguments, make sure the spirv.Branch 1983 // op matches. 1984 SmallVector<Value, 4> blockArgs; 1985 if (!headerBlock->args_empty()) 1986 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; 1987 1988 // The loop entry block should have a unconditional branch jumping to the 1989 // loop header block. 1990 builder.setInsertionPointToEnd(&body.front()); 1991 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock), 1992 ArrayRef<Value>(blockArgs)); 1993 } 1994 1995 // All the blocks cloned into the SelectionOp/LoopOp's region can now be 1996 // cleaned up. 1997 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n"); 1998 // First we need to drop all operands' references inside all blocks. This is 1999 // needed because we can have blocks referencing SSA values from one another. 2000 for (auto *block : constructBlocks) 2001 block->dropAllReferences(); 2002 2003 // Check that whether some op in the to-be-erased blocks still has uses. Those 2004 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's 2005 // region. We cannot handle such cases given that once a value is sinked into 2006 // the SelectionOp/LoopOp's region, there is no escape for it: 2007 // SelectionOp/LooOp does not support yield values right now. 2008 for (auto *block : constructBlocks) { 2009 for (Operation &op : *block) 2010 if (!op.use_empty()) 2011 return op.emitOpError( 2012 "failed control flow structurization: it has uses outside of the " 2013 "enclosing selection/loop construct"); 2014 } 2015 2016 // Then erase all old blocks. 2017 for (auto *block : constructBlocks) { 2018 // We've cloned all blocks belonging to this construct into the structured 2019 // control flow op's region. Among these blocks, some may compose another 2020 // selection/loop. If so, they will be recorded within blockMergeInfo. 2021 // We need to update the pointers there to the newly remapped ones so we can 2022 // continue structurizing them later. 2023 // TODO: The asserts in the following assumes input SPIR-V blob forms 2024 // correctly nested selection/loop constructs. We should relax this and 2025 // support error cases better. 2026 auto it = blockMergeInfo.find(block); 2027 if (it != blockMergeInfo.end()) { 2028 // Use the original location for nested selection/loop ops. 2029 Location loc = it->second.loc; 2030 2031 Block *newHeader = mapper.lookupOrNull(block); 2032 if (!newHeader) 2033 return emitError(loc, "failed control flow structurization: nested " 2034 "loop header block should be remapped!"); 2035 2036 Block *newContinue = it->second.continueBlock; 2037 if (newContinue) { 2038 newContinue = mapper.lookupOrNull(newContinue); 2039 if (!newContinue) 2040 return emitError(loc, "failed control flow structurization: nested " 2041 "loop continue block should be remapped!"); 2042 } 2043 2044 Block *newMerge = it->second.mergeBlock; 2045 if (Block *mappedTo = mapper.lookupOrNull(newMerge)) 2046 newMerge = mappedTo; 2047 2048 // The iterator should be erased before adding a new entry into 2049 // blockMergeInfo to avoid iterator invalidation. 2050 blockMergeInfo.erase(it); 2051 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge, 2052 newContinue); 2053 } 2054 2055 // The structured selection/loop's entry block does not have arguments. 2056 // If the function's header block is also part of the structured control 2057 // flow, we cannot just simply erase it because it may contain arguments 2058 // matching the function signature and used by the cloned blocks. 2059 if (isFnEntryBlock(block)) { 2060 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block 2061 << " to only contain a spirv.Branch op\n"); 2062 // Still keep the function entry block for the potential block arguments, 2063 // but replace all ops inside with a branch to the merge block. 2064 block->clear(); 2065 builder.setInsertionPointToEnd(block); 2066 builder.create<spirv::BranchOp>(location, mergeBlock); 2067 } else { 2068 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n"); 2069 block->erase(); 2070 } 2071 } 2072 2073 LLVM_DEBUG(logger.startLine() 2074 << "[cf] after structurizing construct with header block " 2075 << headerBlock << ":\n" 2076 << *op << "\n"); 2077 2078 return success(); 2079 } 2080 2081 LogicalResult spirv::Deserializer::wireUpBlockArgument() { 2082 LLVM_DEBUG({ 2083 logger.startLine() 2084 << "//----- [phi] start wiring up block arguments -----//\n"; 2085 logger.indent(); 2086 }); 2087 2088 OpBuilder::InsertionGuard guard(opBuilder); 2089 2090 for (const auto &info : blockPhiInfo) { 2091 Block *block = info.first.first; 2092 Block *target = info.first.second; 2093 const BlockPhiInfo &phiInfo = info.second; 2094 LLVM_DEBUG({ 2095 logger.startLine() << "[phi] block " << block << "\n"; 2096 logger.startLine() << "[phi] before creating block argument:\n"; 2097 block->getParentOp()->print(logger.getOStream()); 2098 logger.startLine() << "\n"; 2099 }); 2100 2101 // Set insertion point to before this block's terminator early because we 2102 // may materialize ops via getValue() call. 2103 auto *op = block->getTerminator(); 2104 opBuilder.setInsertionPoint(op); 2105 2106 SmallVector<Value, 4> blockArgs; 2107 blockArgs.reserve(phiInfo.size()); 2108 for (uint32_t valueId : phiInfo) { 2109 if (Value value = getValue(valueId)) { 2110 blockArgs.push_back(value); 2111 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value 2112 << " id = " << valueId << "\n"); 2113 } else { 2114 return emitError(unknownLoc, "OpPhi references undefined value!"); 2115 } 2116 } 2117 2118 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) { 2119 // Replace the previous branch op with a new one with block arguments. 2120 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(), 2121 blockArgs); 2122 branchOp.erase(); 2123 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) { 2124 assert((branchCondOp.getTrueBlock() == target || 2125 branchCondOp.getFalseBlock() == target) && 2126 "expected target to be either the true or false target"); 2127 if (target == branchCondOp.getTrueTarget()) 2128 opBuilder.create<spirv::BranchConditionalOp>( 2129 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs, 2130 branchCondOp.getFalseBlockArguments(), 2131 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(), 2132 branchCondOp.getFalseTarget()); 2133 else 2134 opBuilder.create<spirv::BranchConditionalOp>( 2135 branchCondOp.getLoc(), branchCondOp.getCondition(), 2136 branchCondOp.getTrueBlockArguments(), blockArgs, 2137 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(), 2138 branchCondOp.getFalseBlock()); 2139 2140 branchCondOp.erase(); 2141 } else { 2142 return emitError(unknownLoc, "unimplemented terminator for Phi creation"); 2143 } 2144 2145 LLVM_DEBUG({ 2146 logger.startLine() << "[phi] after creating block argument:\n"; 2147 block->getParentOp()->print(logger.getOStream()); 2148 logger.startLine() << "\n"; 2149 }); 2150 } 2151 blockPhiInfo.clear(); 2152 2153 LLVM_DEBUG({ 2154 logger.unindent(); 2155 logger.startLine() 2156 << "//--- [phi] completed wiring up block arguments ---//\n"; 2157 }); 2158 return success(); 2159 } 2160 2161 LogicalResult spirv::Deserializer::structurizeControlFlow() { 2162 LLVM_DEBUG({ 2163 logger.startLine() 2164 << "//----- [cf] start structurizing control flow -----//\n"; 2165 logger.indent(); 2166 }); 2167 2168 while (!blockMergeInfo.empty()) { 2169 Block *headerBlock = blockMergeInfo.begin()->first; 2170 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; 2171 2172 LLVM_DEBUG({ 2173 logger.startLine() << "[cf] header block " << headerBlock << ":\n"; 2174 headerBlock->print(logger.getOStream()); 2175 logger.startLine() << "\n"; 2176 }); 2177 2178 auto *mergeBlock = mergeInfo.mergeBlock; 2179 assert(mergeBlock && "merge block cannot be nullptr"); 2180 if (!mergeBlock->args_empty()) 2181 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); 2182 LLVM_DEBUG({ 2183 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n"; 2184 mergeBlock->print(logger.getOStream()); 2185 logger.startLine() << "\n"; 2186 }); 2187 2188 auto *continueBlock = mergeInfo.continueBlock; 2189 LLVM_DEBUG(if (continueBlock) { 2190 logger.startLine() << "[cf] continue block " << continueBlock << ":\n"; 2191 continueBlock->print(logger.getOStream()); 2192 logger.startLine() << "\n"; 2193 }); 2194 // Erase this case before calling into structurizer, who will update 2195 // blockMergeInfo. 2196 blockMergeInfo.erase(blockMergeInfo.begin()); 2197 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control, 2198 blockMergeInfo, headerBlock, 2199 mergeBlock, continueBlock 2200 #ifndef NDEBUG 2201 , 2202 logger 2203 #endif 2204 ); 2205 if (failed(structurizer.structurize())) 2206 return failure(); 2207 } 2208 2209 LLVM_DEBUG({ 2210 logger.unindent(); 2211 logger.startLine() 2212 << "//--- [cf] completed structurizing control flow ---//\n"; 2213 }); 2214 return success(); 2215 } 2216 2217 //===----------------------------------------------------------------------===// 2218 // Debug 2219 //===----------------------------------------------------------------------===// 2220 2221 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) { 2222 if (!debugLine) 2223 return unknownLoc; 2224 2225 auto fileName = debugInfoMap.lookup(debugLine->fileID).str(); 2226 if (fileName.empty()) 2227 fileName = "<unknown>"; 2228 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line, 2229 debugLine->column); 2230 } 2231 2232 LogicalResult 2233 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) { 2234 // According to SPIR-V spec: 2235 // "This location information applies to the instructions physically 2236 // following this instruction, up to the first occurrence of any of the 2237 // following: the next end of block, the next OpLine instruction, or the next 2238 // OpNoLine instruction." 2239 if (operands.size() != 3) 2240 return emitError(unknownLoc, "OpLine must have 3 operands"); 2241 debugLine = DebugLine{operands[0], operands[1], operands[2]}; 2242 return success(); 2243 } 2244 2245 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; } 2246 2247 LogicalResult 2248 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) { 2249 if (operands.size() < 2) 2250 return emitError(unknownLoc, "OpString needs at least 2 operands"); 2251 2252 if (!debugInfoMap.lookup(operands[0]).empty()) 2253 return emitError(unknownLoc, 2254 "duplicate debug string found for result <id> ") 2255 << operands[0]; 2256 2257 unsigned wordIndex = 1; 2258 StringRef debugString = decodeStringLiteral(operands, wordIndex); 2259 if (wordIndex != operands.size()) 2260 return emitError(unknownLoc, 2261 "unexpected trailing words in OpString instruction"); 2262 2263 debugInfoMap[operands[0]] = debugString; 2264 return success(); 2265 } 2266