1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Deserializer methods for SPIR-V binary instructions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Deserializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/Location.h" 19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 20 #include "llvm/ADT/STLExtras.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/Support/Debug.h" 23 #include <optional> 24 25 using namespace mlir; 26 27 #define DEBUG_TYPE "spirv-deserialization" 28 29 //===----------------------------------------------------------------------===// 30 // Utility Functions 31 //===----------------------------------------------------------------------===// 32 33 /// Extracts the opcode from the given first word of a SPIR-V instruction. 34 static inline spirv::Opcode extractOpcode(uint32_t word) { 35 return static_cast<spirv::Opcode>(word & 0xffff); 36 } 37 38 //===----------------------------------------------------------------------===// 39 // Instruction 40 //===----------------------------------------------------------------------===// 41 42 Value spirv::Deserializer::getValue(uint32_t id) { 43 if (auto constInfo = getConstant(id)) { 44 // Materialize a `spirv.Constant` op at every use site. 45 return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second, 46 constInfo->first); 47 } 48 if (auto varOp = getGlobalVariable(id)) { 49 auto addressOfOp = opBuilder.create<spirv::AddressOfOp>( 50 unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation())); 51 return addressOfOp.getPointer(); 52 } 53 if (auto constOp = getSpecConstant(id)) { 54 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( 55 unknownLoc, constOp.getDefaultValue().getType(), 56 SymbolRefAttr::get(constOp.getOperation())); 57 return referenceOfOp.getReference(); 58 } 59 if (auto constCompositeOp = getSpecConstantComposite(id)) { 60 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( 61 unknownLoc, constCompositeOp.getType(), 62 SymbolRefAttr::get(constCompositeOp.getOperation())); 63 return referenceOfOp.getReference(); 64 } 65 if (auto specConstOperationInfo = getSpecConstantOperation(id)) { 66 return materializeSpecConstantOperation( 67 id, specConstOperationInfo->enclodesOpcode, 68 specConstOperationInfo->resultTypeID, 69 specConstOperationInfo->enclosedOpOperands); 70 } 71 if (auto undef = getUndefType(id)) { 72 return opBuilder.create<spirv::UndefOp>(unknownLoc, undef); 73 } 74 return valueMap.lookup(id); 75 } 76 77 LogicalResult spirv::Deserializer::sliceInstruction( 78 spirv::Opcode &opcode, ArrayRef<uint32_t> &operands, 79 std::optional<spirv::Opcode> expectedOpcode) { 80 auto binarySize = binary.size(); 81 if (curOffset >= binarySize) { 82 return emitError(unknownLoc, "expected ") 83 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) 84 : "more") 85 << " instruction"; 86 } 87 88 // For each instruction, get its word count from the first word to slice it 89 // from the stream properly, and then dispatch to the instruction handler. 90 91 uint32_t wordCount = binary[curOffset] >> 16; 92 93 if (wordCount == 0) 94 return emitError(unknownLoc, "word count cannot be zero"); 95 96 uint32_t nextOffset = curOffset + wordCount; 97 if (nextOffset > binarySize) 98 return emitError(unknownLoc, "insufficient words for the last instruction"); 99 100 opcode = extractOpcode(binary[curOffset]); 101 operands = binary.slice(curOffset + 1, wordCount - 1); 102 curOffset = nextOffset; 103 return success(); 104 } 105 106 LogicalResult spirv::Deserializer::processInstruction( 107 spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) { 108 LLVM_DEBUG(logger.startLine() << "[inst] processing instruction " 109 << spirv::stringifyOpcode(opcode) << "\n"); 110 111 // First dispatch all the instructions whose opcode does not correspond to 112 // those that have a direct mirror in the SPIR-V dialect 113 switch (opcode) { 114 case spirv::Opcode::OpCapability: 115 return processCapability(operands); 116 case spirv::Opcode::OpExtension: 117 return processExtension(operands); 118 case spirv::Opcode::OpExtInst: 119 return processExtInst(operands); 120 case spirv::Opcode::OpExtInstImport: 121 return processExtInstImport(operands); 122 case spirv::Opcode::OpMemberName: 123 return processMemberName(operands); 124 case spirv::Opcode::OpMemoryModel: 125 return processMemoryModel(operands); 126 case spirv::Opcode::OpEntryPoint: 127 case spirv::Opcode::OpExecutionMode: 128 if (deferInstructions) { 129 deferredInstructions.emplace_back(opcode, operands); 130 return success(); 131 } 132 break; 133 case spirv::Opcode::OpVariable: 134 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) { 135 return processGlobalVariable(operands); 136 } 137 break; 138 case spirv::Opcode::OpLine: 139 return processDebugLine(operands); 140 case spirv::Opcode::OpNoLine: 141 clearDebugLine(); 142 return success(); 143 case spirv::Opcode::OpName: 144 return processName(operands); 145 case spirv::Opcode::OpString: 146 return processDebugString(operands); 147 case spirv::Opcode::OpModuleProcessed: 148 case spirv::Opcode::OpSource: 149 case spirv::Opcode::OpSourceContinued: 150 case spirv::Opcode::OpSourceExtension: 151 // TODO: This is debug information embedded in the binary which should be 152 // translated into the spirv.module. 153 return success(); 154 case spirv::Opcode::OpTypeVoid: 155 case spirv::Opcode::OpTypeBool: 156 case spirv::Opcode::OpTypeInt: 157 case spirv::Opcode::OpTypeFloat: 158 case spirv::Opcode::OpTypeVector: 159 case spirv::Opcode::OpTypeMatrix: 160 case spirv::Opcode::OpTypeArray: 161 case spirv::Opcode::OpTypeFunction: 162 case spirv::Opcode::OpTypeImage: 163 case spirv::Opcode::OpTypeSampledImage: 164 case spirv::Opcode::OpTypeRuntimeArray: 165 case spirv::Opcode::OpTypeStruct: 166 case spirv::Opcode::OpTypePointer: 167 case spirv::Opcode::OpTypeCooperativeMatrixKHR: 168 return processType(opcode, operands); 169 case spirv::Opcode::OpTypeForwardPointer: 170 return processTypeForwardPointer(operands); 171 case spirv::Opcode::OpConstant: 172 return processConstant(operands, /*isSpec=*/false); 173 case spirv::Opcode::OpSpecConstant: 174 return processConstant(operands, /*isSpec=*/true); 175 case spirv::Opcode::OpConstantComposite: 176 return processConstantComposite(operands); 177 case spirv::Opcode::OpSpecConstantComposite: 178 return processSpecConstantComposite(operands); 179 case spirv::Opcode::OpSpecConstantOp: 180 return processSpecConstantOperation(operands); 181 case spirv::Opcode::OpConstantTrue: 182 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); 183 case spirv::Opcode::OpSpecConstantTrue: 184 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); 185 case spirv::Opcode::OpConstantFalse: 186 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); 187 case spirv::Opcode::OpSpecConstantFalse: 188 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); 189 case spirv::Opcode::OpConstantNull: 190 return processConstantNull(operands); 191 case spirv::Opcode::OpDecorate: 192 return processDecoration(operands); 193 case spirv::Opcode::OpMemberDecorate: 194 return processMemberDecoration(operands); 195 case spirv::Opcode::OpFunction: 196 return processFunction(operands); 197 case spirv::Opcode::OpLabel: 198 return processLabel(operands); 199 case spirv::Opcode::OpBranch: 200 return processBranch(operands); 201 case spirv::Opcode::OpBranchConditional: 202 return processBranchConditional(operands); 203 case spirv::Opcode::OpSelectionMerge: 204 return processSelectionMerge(operands); 205 case spirv::Opcode::OpLoopMerge: 206 return processLoopMerge(operands); 207 case spirv::Opcode::OpPhi: 208 return processPhi(operands); 209 case spirv::Opcode::OpUndef: 210 return processUndef(operands); 211 default: 212 break; 213 } 214 return dispatchToAutogenDeserialization(opcode, operands); 215 } 216 217 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr( 218 ArrayRef<uint32_t> words, StringRef opName, bool hasResult, 219 unsigned numOperands) { 220 SmallVector<Type, 1> resultTypes; 221 uint32_t valueID = 0; 222 223 size_t wordIndex = 0; 224 if (hasResult) { 225 if (wordIndex >= words.size()) 226 return emitError(unknownLoc, 227 "expected result type <id> while deserializing for ") 228 << opName; 229 230 // Decode the type <id> 231 auto type = getType(words[wordIndex]); 232 if (!type) 233 return emitError(unknownLoc, "unknown type result <id>: ") 234 << words[wordIndex]; 235 resultTypes.push_back(type); 236 ++wordIndex; 237 238 // Decode the result <id> 239 if (wordIndex >= words.size()) 240 return emitError(unknownLoc, 241 "expected result <id> while deserializing for ") 242 << opName; 243 valueID = words[wordIndex]; 244 ++wordIndex; 245 } 246 247 SmallVector<Value, 4> operands; 248 SmallVector<NamedAttribute, 4> attributes; 249 250 // Decode operands 251 size_t operandIndex = 0; 252 for (; operandIndex < numOperands && wordIndex < words.size(); 253 ++operandIndex, ++wordIndex) { 254 auto arg = getValue(words[wordIndex]); 255 if (!arg) 256 return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex]; 257 operands.push_back(arg); 258 } 259 if (operandIndex != numOperands) { 260 return emitError( 261 unknownLoc, 262 "found less operands than expected when deserializing for ") 263 << opName << "; only " << operandIndex << " of " << numOperands 264 << " processed"; 265 } 266 if (wordIndex != words.size()) { 267 return emitError( 268 unknownLoc, 269 "found more operands than expected when deserializing for ") 270 << opName << "; only " << wordIndex << " of " << words.size() 271 << " processed"; 272 } 273 274 // Attach attributes from decorations 275 if (decorations.count(valueID)) { 276 auto attrs = decorations[valueID].getAttrs(); 277 attributes.append(attrs.begin(), attrs.end()); 278 } 279 280 // Create the op and update bookkeeping maps 281 Location loc = createFileLineColLoc(opBuilder); 282 OperationState opState(loc, opName); 283 opState.addOperands(operands); 284 if (hasResult) 285 opState.addTypes(resultTypes); 286 opState.addAttributes(attributes); 287 Operation *op = opBuilder.create(opState); 288 if (hasResult) 289 valueMap[valueID] = op->getResult(0); 290 291 if (op->hasTrait<OpTrait::IsTerminator>()) 292 clearDebugLine(); 293 294 return success(); 295 } 296 297 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) { 298 if (operands.size() != 2) { 299 return emitError(unknownLoc, "OpUndef instruction must have two operands"); 300 } 301 auto type = getType(operands[0]); 302 if (!type) { 303 return emitError(unknownLoc, "unknown type <id> with OpUndef instruction"); 304 } 305 undefMap[operands[1]] = type; 306 return success(); 307 } 308 309 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) { 310 if (operands.size() < 4) { 311 return emitError(unknownLoc, 312 "OpExtInst must have at least 4 operands, result type " 313 "<id>, result <id>, set <id> and instruction opcode"); 314 } 315 if (!extendedInstSets.count(operands[2])) { 316 return emitError(unknownLoc, "undefined set <id> in OpExtInst"); 317 } 318 SmallVector<uint32_t, 4> slicedOperands; 319 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); 320 slicedOperands.append(std::next(operands.begin(), 4), operands.end()); 321 return dispatchToExtensionSetAutogenDeserialization( 322 extendedInstSets[operands[2]], operands[3], slicedOperands); 323 } 324 325 namespace mlir { 326 namespace spirv { 327 328 template <> 329 LogicalResult 330 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { 331 unsigned wordIndex = 0; 332 if (wordIndex >= words.size()) { 333 return emitError(unknownLoc, 334 "missing Execution Model specification in OpEntryPoint"); 335 } 336 auto execModel = spirv::ExecutionModelAttr::get( 337 context, static_cast<spirv::ExecutionModel>(words[wordIndex++])); 338 if (wordIndex >= words.size()) { 339 return emitError(unknownLoc, "missing <id> in OpEntryPoint"); 340 } 341 // Get the function <id> 342 auto fnID = words[wordIndex++]; 343 // Get the function name 344 auto fnName = decodeStringLiteral(words, wordIndex); 345 // Verify that the function <id> matches the fnName 346 auto parsedFunc = getFunction(fnID); 347 if (!parsedFunc) { 348 return emitError(unknownLoc, "no function matching <id> ") << fnID; 349 } 350 if (parsedFunc.getName() != fnName) { 351 // The deserializer uses "spirv_fn_<id>" as the function name if the input 352 // SPIR-V blob does not contain a name for it. We should use a more clear 353 // indication for such case rather than relying on naming details. 354 if (!parsedFunc.getName().starts_with("spirv_fn_")) 355 return emitError(unknownLoc, 356 "function name mismatch between OpEntryPoint " 357 "and OpFunction with <id> ") 358 << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); 359 parsedFunc.setName(fnName); 360 } 361 SmallVector<Attribute, 4> interface; 362 while (wordIndex < words.size()) { 363 auto arg = getGlobalVariable(words[wordIndex]); 364 if (!arg) { 365 return emitError(unknownLoc, "undefined result <id> ") 366 << words[wordIndex] << " while decoding OpEntryPoint"; 367 } 368 interface.push_back(SymbolRefAttr::get(arg.getOperation())); 369 wordIndex++; 370 } 371 opBuilder.create<spirv::EntryPointOp>( 372 unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName), 373 opBuilder.getArrayAttr(interface)); 374 return success(); 375 } 376 377 template <> 378 LogicalResult 379 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) { 380 unsigned wordIndex = 0; 381 if (wordIndex >= words.size()) { 382 return emitError(unknownLoc, 383 "missing function result <id> in OpExecutionMode"); 384 } 385 // Get the function <id> to get the name of the function 386 auto fnID = words[wordIndex++]; 387 auto fn = getFunction(fnID); 388 if (!fn) { 389 return emitError(unknownLoc, "no function matching <id> ") << fnID; 390 } 391 // Get the Execution mode 392 if (wordIndex >= words.size()) { 393 return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); 394 } 395 auto execMode = spirv::ExecutionModeAttr::get( 396 context, static_cast<spirv::ExecutionMode>(words[wordIndex++])); 397 398 // Get the values 399 SmallVector<Attribute, 4> attrListElems; 400 while (wordIndex < words.size()) { 401 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); 402 } 403 auto values = opBuilder.getArrayAttr(attrListElems); 404 opBuilder.create<spirv::ExecutionModeOp>( 405 unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), 406 execMode, values); 407 return success(); 408 } 409 410 template <> 411 LogicalResult 412 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) { 413 if (operands.size() < 3) { 414 return emitError(unknownLoc, 415 "OpFunctionCall must have at least 3 operands"); 416 } 417 418 Type resultType = getType(operands[0]); 419 if (!resultType) { 420 return emitError(unknownLoc, "undefined result type from <id> ") 421 << operands[0]; 422 } 423 424 // Use null type to mean no result type. 425 if (isVoidType(resultType)) 426 resultType = nullptr; 427 428 auto resultID = operands[1]; 429 auto functionID = operands[2]; 430 431 auto functionName = getFunctionSymbol(functionID); 432 433 SmallVector<Value, 4> arguments; 434 for (auto operand : llvm::drop_begin(operands, 3)) { 435 auto value = getValue(operand); 436 if (!value) { 437 return emitError(unknownLoc, "unknown <id> ") 438 << operand << " used by OpFunctionCall"; 439 } 440 arguments.push_back(value); 441 } 442 443 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>( 444 unknownLoc, resultType, 445 SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments); 446 447 if (resultType) 448 valueMap[resultID] = opFunctionCall.getResult(0); 449 return success(); 450 } 451 452 template <> 453 LogicalResult 454 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) { 455 SmallVector<Type, 1> resultTypes; 456 size_t wordIndex = 0; 457 SmallVector<Value, 4> operands; 458 SmallVector<NamedAttribute, 4> attributes; 459 460 if (wordIndex < words.size()) { 461 auto arg = getValue(words[wordIndex]); 462 463 if (!arg) { 464 return emitError(unknownLoc, "unknown result <id> : ") 465 << words[wordIndex]; 466 } 467 468 operands.push_back(arg); 469 wordIndex++; 470 } 471 472 if (wordIndex < words.size()) { 473 auto arg = getValue(words[wordIndex]); 474 475 if (!arg) { 476 return emitError(unknownLoc, "unknown result <id> : ") 477 << words[wordIndex]; 478 } 479 480 operands.push_back(arg); 481 wordIndex++; 482 } 483 484 bool isAlignedAttr = false; 485 486 if (wordIndex < words.size()) { 487 auto attrValue = words[wordIndex++]; 488 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>( 489 static_cast<spirv::MemoryAccess>(attrValue)); 490 attributes.push_back( 491 opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr)); 492 isAlignedAttr = (attrValue == 2); 493 } 494 495 if (isAlignedAttr && wordIndex < words.size()) { 496 attributes.push_back(opBuilder.getNamedAttr( 497 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); 498 } 499 500 if (wordIndex < words.size()) { 501 auto attrValue = words[wordIndex++]; 502 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>( 503 static_cast<spirv::MemoryAccess>(attrValue)); 504 attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr)); 505 } 506 507 if (wordIndex < words.size()) { 508 attributes.push_back(opBuilder.getNamedAttr( 509 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); 510 } 511 512 if (wordIndex != words.size()) { 513 return emitError(unknownLoc, 514 "found more operands than expected when deserializing " 515 "spirv::CopyMemoryOp, only ") 516 << wordIndex << " of " << words.size() << " processed"; 517 } 518 519 Location loc = createFileLineColLoc(opBuilder); 520 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes); 521 522 return success(); 523 } 524 525 template <> 526 LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>( 527 ArrayRef<uint32_t> words) { 528 if (words.size() != 4) { 529 return emitError(unknownLoc, 530 "expected 4 words in GenericCastToPtrExplicitOp" 531 " but got : ") 532 << words.size(); 533 } 534 SmallVector<Type, 1> resultTypes; 535 SmallVector<Value, 4> operands; 536 uint32_t valueID = 0; 537 auto type = getType(words[0]); 538 539 if (!type) 540 return emitError(unknownLoc, "unknown type result <id> : ") << words[0]; 541 resultTypes.push_back(type); 542 543 valueID = words[1]; 544 545 auto arg = getValue(words[2]); 546 if (!arg) 547 return emitError(unknownLoc, "unknown result <id> : ") << words[2]; 548 operands.push_back(arg); 549 550 Location loc = createFileLineColLoc(opBuilder); 551 Operation *op = opBuilder.create<spirv::GenericCastToPtrExplicitOp>( 552 loc, resultTypes, operands); 553 valueMap[valueID] = op->getResult(0); 554 return success(); 555 } 556 557 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and 558 // various Deserializer::processOp<...>() specializations. 559 #define GET_DESERIALIZATION_FNS 560 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 561 562 } // namespace spirv 563 } // namespace mlir 564