1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the serialization methods for MLIR SPIR-V module ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Serializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 17 #include "mlir/IR/RegionGraphTraits.h" 18 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 19 #include "llvm/ADT/DepthFirstIterator.h" 20 #include "llvm/ADT/StringExtras.h" 21 #include "llvm/Support/Debug.h" 22 23 #define DEBUG_TYPE "spirv-serialization" 24 25 using namespace mlir; 26 27 /// A pre-order depth-first visitor function for processing basic blocks. 28 /// 29 /// Visits the basic blocks starting from the given `headerBlock` in pre-order 30 /// depth-first manner and calls `blockHandler` on each block. Skips handling 31 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` 32 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s 33 /// successors. 34 /// 35 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order 36 /// of blocks in a function must satisfy the rule that blocks appear before 37 /// all blocks they dominate." This can be achieved by a pre-order CFG 38 /// traversal algorithm. To make the serialization output more logical and 39 /// readable to human, we perform depth-first CFG traversal and delay the 40 /// serialization of the merge block and the continue block, if exists, until 41 /// after all other blocks have been processed. 42 static LogicalResult 43 visitInPrettyBlockOrder(Block *headerBlock, 44 function_ref<LogicalResult(Block *)> blockHandler, 45 bool skipHeader = false, BlockRange skipBlocks = {}) { 46 llvm::df_iterator_default_set<Block *, 4> doneBlocks; 47 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); 48 49 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) { 50 if (skipHeader && block == headerBlock) 51 continue; 52 if (failed(blockHandler(block))) 53 return failure(); 54 } 55 return success(); 56 } 57 58 namespace mlir { 59 namespace spirv { 60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { 61 if (auto resultID = 62 prepareConstant(op.getLoc(), op.getType(), op.getValue())) { 63 valueIDMap[op.getResult()] = resultID; 64 return success(); 65 } 66 return failure(); 67 } 68 69 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { 70 if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(), 71 /*isSpec=*/true)) { 72 // Emit the OpDecorate instruction for SpecId. 73 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) { 74 auto val = static_cast<uint32_t>(specID.getInt()); 75 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val}))) 76 return failure(); 77 } 78 79 specConstIDMap[op.getSymName()] = resultID; 80 return processName(resultID, op.getSymName()); 81 } 82 return failure(); 83 } 84 85 LogicalResult 86 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { 87 uint32_t typeID = 0; 88 if (failed(processType(op.getLoc(), op.getType(), typeID))) { 89 return failure(); 90 } 91 92 auto resultID = getNextID(); 93 94 SmallVector<uint32_t, 8> operands; 95 operands.push_back(typeID); 96 operands.push_back(resultID); 97 98 auto constituents = op.getConstituents(); 99 100 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { 101 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]); 102 103 auto constituentName = constituent.getValue(); 104 auto constituentID = getSpecConstID(constituentName); 105 106 if (!constituentID) { 107 return op.emitError("unknown result <id> for specialization constant ") 108 << constituentName; 109 } 110 111 operands.push_back(constituentID); 112 } 113 114 encodeInstructionInto(typesGlobalValues, 115 spirv::Opcode::OpSpecConstantComposite, operands); 116 specConstIDMap[op.getSymName()] = resultID; 117 118 return processName(resultID, op.getSymName()); 119 } 120 121 LogicalResult 122 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { 123 uint32_t typeID = 0; 124 if (failed(processType(op.getLoc(), op.getType(), typeID))) { 125 return failure(); 126 } 127 128 auto resultID = getNextID(); 129 130 SmallVector<uint32_t, 8> operands; 131 operands.push_back(typeID); 132 operands.push_back(resultID); 133 134 Block &block = op.getRegion().getBlocks().front(); 135 Operation &enclosedOp = block.getOperations().front(); 136 137 std::string enclosedOpName; 138 llvm::raw_string_ostream rss(enclosedOpName); 139 rss << "Op" << enclosedOp.getName().stripDialect(); 140 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName); 141 142 if (!enclosedOpcode) { 143 op.emitError("Couldn't find op code for op ") 144 << enclosedOp.getName().getStringRef(); 145 return failure(); 146 } 147 148 operands.push_back(static_cast<uint32_t>(*enclosedOpcode)); 149 150 // Append operands to the enclosed op to the list of operands. 151 for (Value operand : enclosedOp.getOperands()) { 152 uint32_t id = getValueID(operand); 153 assert(id && "use before def!"); 154 operands.push_back(id); 155 } 156 157 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp, 158 operands); 159 valueIDMap[op.getResult()] = resultID; 160 161 return success(); 162 } 163 164 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { 165 auto undefType = op.getType(); 166 auto &id = undefValIDMap[undefType]; 167 if (!id) { 168 id = getNextID(); 169 uint32_t typeID = 0; 170 if (failed(processType(op.getLoc(), undefType, typeID))) 171 return failure(); 172 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, 173 {typeID, id}); 174 } 175 valueIDMap[op.getResult()] = id; 176 return success(); 177 } 178 179 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) { 180 for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { 181 uint32_t argTypeID = 0; 182 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { 183 return failure(); 184 } 185 auto argValueID = getNextID(); 186 187 // Process decoration attributes of arguments. 188 auto funcOp = cast<FunctionOpInterface>(*op); 189 for (auto argAttr : funcOp.getArgAttrs(idx)) { 190 if (argAttr.getName() != DecorationAttr::name) 191 continue; 192 193 if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) { 194 if (failed(processDecorationAttr(op->getLoc(), argValueID, 195 decAttr.getValue(), decAttr))) 196 return failure(); 197 } 198 } 199 200 valueIDMap[arg] = argValueID; 201 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, 202 {argTypeID, argValueID}); 203 } 204 return success(); 205 } 206 207 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { 208 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); 209 assert(functionHeader.empty() && functionBody.empty()); 210 211 uint32_t fnTypeID = 0; 212 // Generate type of the function. 213 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID))) 214 return failure(); 215 216 // Add the function definition. 217 SmallVector<uint32_t, 4> operands; 218 uint32_t resTypeID = 0; 219 auto resultTypes = op.getFunctionType().getResults(); 220 if (resultTypes.size() > 1) { 221 return op.emitError("cannot serialize function with multiple return types"); 222 } 223 if (failed(processType(op.getLoc(), 224 (resultTypes.empty() ? getVoidType() : resultTypes[0]), 225 resTypeID))) { 226 return failure(); 227 } 228 operands.push_back(resTypeID); 229 auto funcID = getOrCreateFunctionID(op.getName()); 230 operands.push_back(funcID); 231 operands.push_back(static_cast<uint32_t>(op.getFunctionControl())); 232 operands.push_back(fnTypeID); 233 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); 234 235 // Add function name. 236 if (failed(processName(funcID, op.getName()))) { 237 return failure(); 238 } 239 // Handle external functions with linkage_attributes(LinkageAttributes) 240 // differently. 241 auto linkageAttr = op.getLinkageAttributes(); 242 auto hasImportLinkage = 243 linkageAttr && (linkageAttr.value().getLinkageType().getValue() == 244 spirv::LinkageType::Import); 245 if (op.isExternal() && !hasImportLinkage) { 246 return op.emitError( 247 "'spirv.module' cannot contain external functions " 248 "without 'Import' linkage_attributes (LinkageAttributes)"); 249 } 250 if (op.isExternal() && hasImportLinkage) { 251 // Add an entry block to set up the block arguments 252 // to match the signature of the function. 253 // This is to generate OpFunctionParameter for functions with 254 // LinkageAttributes. 255 // WARNING: This operation has side-effect, it essentially adds a body 256 // to the func. Hence, making it not external anymore (isExternal() 257 // is going to return false for this function from now on) 258 // Hence, we'll remove the body once we are done with the serialization. 259 op.addEntryBlock(); 260 if (failed(processFuncParameter(op))) 261 return failure(); 262 // Don't need to process the added block, there is nothing to process, 263 // the fake body was added just to get the arguments, remove the body, 264 // since it's use is done. 265 op.eraseBody(); 266 } else { 267 if (failed(processFuncParameter(op))) 268 return failure(); 269 270 // Some instructions (e.g., OpVariable) in a function must be in the first 271 // block in the function. These instructions will be put in 272 // functionHeader. Thus, we put the label in functionHeader first, and 273 // omit it from the first block. OpLabel only needs to be added for 274 // functions with body (including empty body). Since, we added a fake body 275 // for functions with 'Import' Linkage attributes, these functions are 276 // essentially function delcaration, so they should not have OpLabel and a 277 // terminating instruction. That's why we skipped it for those functions. 278 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, 279 {getOrCreateBlockID(&op.front())}); 280 if (failed(processBlock(&op.front(), /*omitLabel=*/true))) 281 return failure(); 282 if (failed(visitInPrettyBlockOrder( 283 &op.front(), [&](Block *block) { return processBlock(block); }, 284 /*skipHeader=*/true))) { 285 return failure(); 286 } 287 288 // There might be OpPhi instructions who have value references needing to 289 // fix. 290 for (const auto &deferredValue : deferredPhiValues) { 291 Value value = deferredValue.first; 292 uint32_t id = getValueID(value); 293 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value 294 << " to id = " << id << '\n'); 295 assert(id && "OpPhi references undefined value!"); 296 for (size_t offset : deferredValue.second) 297 functionBody[offset] = id; 298 } 299 deferredPhiValues.clear(); 300 } 301 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() 302 << "' --\n"); 303 // Insert Decorations based on Function Attributes. 304 // Only attributes we should be considering for decoration are the 305 // ::mlir::spirv::Decoration attributes. 306 307 for (auto attr : op->getAttrs()) { 308 // Only generate OpDecorate op for spirv::Decoration attributes. 309 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>( 310 llvm::convertToCamelFromSnakeCase(attr.getName().strref(), 311 /*capitalizeFirst=*/true)); 312 if (isValidDecoration != std::nullopt) { 313 if (failed(processDecoration(op.getLoc(), funcID, attr))) { 314 return failure(); 315 } 316 } 317 } 318 // Insert OpFunctionEnd. 319 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {}); 320 321 functions.append(functionHeader.begin(), functionHeader.end()); 322 functions.append(functionBody.begin(), functionBody.end()); 323 functionHeader.clear(); 324 functionBody.clear(); 325 326 return success(); 327 } 328 329 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { 330 SmallVector<uint32_t, 4> operands; 331 SmallVector<StringRef, 2> elidedAttrs; 332 uint32_t resultID = 0; 333 uint32_t resultTypeID = 0; 334 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { 335 return failure(); 336 } 337 operands.push_back(resultTypeID); 338 resultID = getNextID(); 339 valueIDMap[op.getResult()] = resultID; 340 operands.push_back(resultID); 341 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>()); 342 if (attr) { 343 operands.push_back( 344 static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue())); 345 } 346 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); 347 for (auto arg : op.getODSOperands(0)) { 348 auto argID = getValueID(arg); 349 if (!argID) { 350 return emitError(op.getLoc(), "operand 0 has a use before def"); 351 } 352 operands.push_back(argID); 353 } 354 if (failed(emitDebugLine(functionHeader, op.getLoc()))) 355 return failure(); 356 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); 357 for (auto attr : op->getAttrs()) { 358 if (llvm::any_of(elidedAttrs, [&](StringRef elided) { 359 return attr.getName() == elided; 360 })) { 361 continue; 362 } 363 if (failed(processDecoration(op.getLoc(), resultID, attr))) { 364 return failure(); 365 } 366 } 367 return success(); 368 } 369 370 LogicalResult 371 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { 372 // Get TypeID. 373 uint32_t resultTypeID = 0; 374 SmallVector<StringRef, 4> elidedAttrs; 375 if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) { 376 return failure(); 377 } 378 379 elidedAttrs.push_back("type"); 380 SmallVector<uint32_t, 4> operands; 381 operands.push_back(resultTypeID); 382 auto resultID = getNextID(); 383 384 // Encode the name. 385 auto varName = varOp.getSymName(); 386 elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); 387 if (failed(processName(resultID, varName))) { 388 return failure(); 389 } 390 globalVarIDMap[varName] = resultID; 391 operands.push_back(resultID); 392 393 // Encode StorageClass. 394 operands.push_back(static_cast<uint32_t>(varOp.storageClass())); 395 396 // Encode initialization. 397 StringRef initAttrName = varOp.getInitializerAttrName().getValue(); 398 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) { 399 uint32_t initializerID = 0; 400 auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName); 401 Operation *initOp = SymbolTable::lookupNearestSymbolFrom( 402 varOp->getParentOp(), initRef.getAttr()); 403 404 // Check if initializer is GlobalVariable or SpecConstant* cases. 405 if (isa<spirv::GlobalVariableOp>(initOp)) 406 initializerID = getVariableID(*initSymbolName); 407 else 408 initializerID = getSpecConstID(*initSymbolName); 409 410 if (!initializerID) 411 return emitError(varOp.getLoc(), 412 "invalid usage of undefined variable as initializer"); 413 414 operands.push_back(initializerID); 415 elidedAttrs.push_back(initAttrName); 416 } 417 418 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc()))) 419 return failure(); 420 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands); 421 elidedAttrs.push_back(initAttrName); 422 423 // Encode decorations. 424 for (auto attr : varOp->getAttrs()) { 425 if (llvm::any_of(elidedAttrs, [&](StringRef elided) { 426 return attr.getName() == elided; 427 })) { 428 continue; 429 } 430 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { 431 return failure(); 432 } 433 } 434 return success(); 435 } 436 437 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { 438 // Assign <id>s to all blocks so that branches inside the SelectionOp can 439 // resolve properly. 440 auto &body = selectionOp.getBody(); 441 for (Block &block : body) 442 getOrCreateBlockID(&block); 443 444 auto *headerBlock = selectionOp.getHeaderBlock(); 445 auto *mergeBlock = selectionOp.getMergeBlock(); 446 auto headerID = getBlockID(headerBlock); 447 auto mergeID = getBlockID(mergeBlock); 448 auto loc = selectionOp.getLoc(); 449 450 // This SelectionOp is in some MLIR block with preceding and following ops. In 451 // the binary format, it should reside in separate SPIR-V blocks from its 452 // preceding and following ops. So we need to emit unconditional branches to 453 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal 454 // flow afterwards. 455 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); 456 457 // Emit the selection header block, which dominates all other blocks, first. 458 // We need to emit an OpSelectionMerge instruction before the selection header 459 // block's terminator. 460 auto emitSelectionMerge = [&]() { 461 if (failed(emitDebugLine(functionBody, loc))) 462 return failure(); 463 lastProcessedWasMergeInst = true; 464 encodeInstructionInto( 465 functionBody, spirv::Opcode::OpSelectionMerge, 466 {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())}); 467 return success(); 468 }; 469 if (failed( 470 processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge))) 471 return failure(); 472 473 // Process all blocks with a depth-first visitor starting from the header 474 // block. The selection header block and merge block are skipped by this 475 // visitor. 476 if (failed(visitInPrettyBlockOrder( 477 headerBlock, [&](Block *block) { return processBlock(block); }, 478 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) 479 return failure(); 480 481 // There is nothing to do for the merge block in the selection, which just 482 // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel 483 // instruction to start a new SPIR-V block for ops following this SelectionOp. 484 // The block should use the <id> for the merge block. 485 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 486 LLVM_DEBUG(llvm::dbgs() << "done merge "); 487 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); 488 LLVM_DEBUG(llvm::dbgs() << "\n"); 489 return success(); 490 } 491 492 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { 493 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve 494 // properly. We don't need to assign for the entry block, which is just for 495 // satisfying MLIR region's structural requirement. 496 auto &body = loopOp.getBody(); 497 for (Block &block : llvm::drop_begin(body)) 498 getOrCreateBlockID(&block); 499 500 auto *headerBlock = loopOp.getHeaderBlock(); 501 auto *continueBlock = loopOp.getContinueBlock(); 502 auto *mergeBlock = loopOp.getMergeBlock(); 503 auto headerID = getBlockID(headerBlock); 504 auto continueID = getBlockID(continueBlock); 505 auto mergeID = getBlockID(mergeBlock); 506 auto loc = loopOp.getLoc(); 507 508 // This LoopOp is in some MLIR block with preceding and following ops. In the 509 // binary format, it should reside in separate SPIR-V blocks from its 510 // preceding and following ops. So we need to emit unconditional branches to 511 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow 512 // afterwards. 513 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); 514 515 // LoopOp's entry block is just there for satisfying MLIR's structural 516 // requirements so we omit it and start serialization from the loop header 517 // block. 518 519 // Emit the loop header block, which dominates all other blocks, first. We 520 // need to emit an OpLoopMerge instruction before the loop header block's 521 // terminator. 522 auto emitLoopMerge = [&]() { 523 if (failed(emitDebugLine(functionBody, loc))) 524 return failure(); 525 lastProcessedWasMergeInst = true; 526 encodeInstructionInto( 527 functionBody, spirv::Opcode::OpLoopMerge, 528 {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())}); 529 return success(); 530 }; 531 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) 532 return failure(); 533 534 // Process all blocks with a depth-first visitor starting from the header 535 // block. The loop header block, loop continue block, and loop merge block are 536 // skipped by this visitor and handled later in this function. 537 if (failed(visitInPrettyBlockOrder( 538 headerBlock, [&](Block *block) { return processBlock(block); }, 539 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) 540 return failure(); 541 542 // We have handled all other blocks. Now get to the loop continue block. 543 if (failed(processBlock(continueBlock))) 544 return failure(); 545 546 // There is nothing to do for the merge block in the loop, which just contains 547 // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction 548 // to start a new SPIR-V block for ops following this LoopOp. The block should 549 // use the <id> for the merge block. 550 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 551 LLVM_DEBUG(llvm::dbgs() << "done merge "); 552 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); 553 LLVM_DEBUG(llvm::dbgs() << "\n"); 554 return success(); 555 } 556 557 LogicalResult Serializer::processBranchConditionalOp( 558 spirv::BranchConditionalOp condBranchOp) { 559 auto conditionID = getValueID(condBranchOp.getCondition()); 560 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); 561 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); 562 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; 563 564 if (auto weights = condBranchOp.getBranchWeights()) { 565 for (auto val : weights->getValue()) 566 arguments.push_back(cast<IntegerAttr>(val).getInt()); 567 } 568 569 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc()))) 570 return failure(); 571 encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, 572 arguments); 573 return success(); 574 } 575 576 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { 577 if (failed(emitDebugLine(functionBody, branchOp.getLoc()))) 578 return failure(); 579 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, 580 {getOrCreateBlockID(branchOp.getTarget())}); 581 return success(); 582 } 583 584 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { 585 auto varName = addressOfOp.getVariable(); 586 auto variableID = getVariableID(varName); 587 if (!variableID) { 588 return addressOfOp.emitError("unknown result <id> for variable ") 589 << varName; 590 } 591 valueIDMap[addressOfOp.getPointer()] = variableID; 592 return success(); 593 } 594 595 LogicalResult 596 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { 597 auto constName = referenceOfOp.getSpecConst(); 598 auto constID = getSpecConstID(constName); 599 if (!constID) { 600 return referenceOfOp.emitError( 601 "unknown result <id> for specialization constant ") 602 << constName; 603 } 604 valueIDMap[referenceOfOp.getReference()] = constID; 605 return success(); 606 } 607 608 template <> 609 LogicalResult 610 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { 611 SmallVector<uint32_t, 4> operands; 612 // Add the ExecutionModel. 613 operands.push_back(static_cast<uint32_t>(op.getExecutionModel())); 614 // Add the function <id>. 615 auto funcID = getFunctionID(op.getFn()); 616 if (!funcID) { 617 return op.emitError("missing <id> for function ") 618 << op.getFn() 619 << "; function needs to be defined before spirv.EntryPoint is " 620 "serialized"; 621 } 622 operands.push_back(funcID); 623 // Add the name of the function. 624 spirv::encodeStringLiteralInto(operands, op.getFn()); 625 626 // Add the interface values. 627 if (auto interface = op.getInterface()) { 628 for (auto var : interface.getValue()) { 629 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue()); 630 if (!id) { 631 return op.emitError( 632 "referencing undefined global variable." 633 "spirv.EntryPoint is at the end of spirv.module. All " 634 "referenced variables should already be defined"); 635 } 636 operands.push_back(id); 637 } 638 } 639 encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); 640 return success(); 641 } 642 643 template <> 644 LogicalResult 645 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { 646 SmallVector<uint32_t, 4> operands; 647 // Add the function <id>. 648 auto funcID = getFunctionID(op.getFn()); 649 if (!funcID) { 650 return op.emitError("missing <id> for function ") 651 << op.getFn() 652 << "; function needs to be serialized before ExecutionModeOp is " 653 "serialized"; 654 } 655 operands.push_back(funcID); 656 // Add the ExecutionMode. 657 operands.push_back(static_cast<uint32_t>(op.getExecutionMode())); 658 659 // Serialize values if any. 660 auto values = op.getValues(); 661 if (values) { 662 for (auto &intVal : values.getValue()) { 663 operands.push_back(static_cast<uint32_t>( 664 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue())); 665 } 666 } 667 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, 668 operands); 669 return success(); 670 } 671 672 template <> 673 LogicalResult 674 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { 675 auto funcName = op.getCallee(); 676 uint32_t resTypeID = 0; 677 678 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); 679 if (failed(processType(op.getLoc(), resultTy, resTypeID))) 680 return failure(); 681 682 auto funcID = getOrCreateFunctionID(funcName); 683 auto funcCallID = getNextID(); 684 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; 685 686 for (auto value : op.getArguments()) { 687 auto valueID = getValueID(value); 688 assert(valueID && "cannot find a value for spirv.FunctionCall"); 689 operands.push_back(valueID); 690 } 691 692 if (!isa<NoneType>(resultTy)) 693 valueIDMap[op.getResult(0)] = funcCallID; 694 695 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); 696 return success(); 697 } 698 699 template <> 700 LogicalResult 701 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { 702 SmallVector<uint32_t, 4> operands; 703 SmallVector<StringRef, 2> elidedAttrs; 704 705 for (Value operand : op->getOperands()) { 706 auto id = getValueID(operand); 707 assert(id && "use before def!"); 708 operands.push_back(id); 709 } 710 711 StringAttr memoryAccess = op.getMemoryAccessAttrName(); 712 if (auto attr = op->getAttr(memoryAccess)) { 713 operands.push_back( 714 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); 715 } 716 717 elidedAttrs.push_back(memoryAccess.strref()); 718 719 StringAttr alignment = op.getAlignmentAttrName(); 720 if (auto attr = op->getAttr(alignment)) { 721 operands.push_back(static_cast<uint32_t>( 722 cast<IntegerAttr>(attr).getValue().getZExtValue())); 723 } 724 725 elidedAttrs.push_back(alignment.strref()); 726 727 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName(); 728 if (auto attr = op->getAttr(sourceMemoryAccess)) { 729 operands.push_back( 730 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); 731 } 732 733 elidedAttrs.push_back(sourceMemoryAccess.strref()); 734 735 StringAttr sourceAlignment = op.getSourceAlignmentAttrName(); 736 if (auto attr = op->getAttr(sourceAlignment)) { 737 operands.push_back(static_cast<uint32_t>( 738 cast<IntegerAttr>(attr).getValue().getZExtValue())); 739 } 740 741 elidedAttrs.push_back(sourceAlignment.strref()); 742 if (failed(emitDebugLine(functionBody, op.getLoc()))) 743 return failure(); 744 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); 745 746 return success(); 747 } 748 template <> 749 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>( 750 spirv::GenericCastToPtrExplicitOp op) { 751 SmallVector<uint32_t, 4> operands; 752 Type resultTy; 753 Location loc = op->getLoc(); 754 uint32_t resultTypeID = 0; 755 uint32_t resultID = 0; 756 resultTy = op->getResult(0).getType(); 757 if (failed(processType(loc, resultTy, resultTypeID))) 758 return failure(); 759 operands.push_back(resultTypeID); 760 761 resultID = getNextID(); 762 operands.push_back(resultID); 763 valueIDMap[op->getResult(0)] = resultID; 764 765 for (Value operand : op->getOperands()) 766 operands.push_back(getValueID(operand)); 767 spirv::StorageClass resultStorage = 768 cast<spirv::PointerType>(resultTy).getStorageClass(); 769 operands.push_back(static_cast<uint32_t>(resultStorage)); 770 encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit, 771 operands); 772 return success(); 773 } 774 775 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and 776 // various Serializer::processOp<...>() specializations. 777 #define GET_SERIALIZATION_FNS 778 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 779 780 } // namespace spirv 781 } // namespace mlir 782