1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the serialization methods for MLIR SPIR-V module ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Serializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 17 #include "mlir/IR/RegionGraphTraits.h" 18 #include "mlir/Support/LogicalResult.h" 19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 20 #include "llvm/ADT/DepthFirstIterator.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/Support/Debug.h" 23 24 #define DEBUG_TYPE "spirv-serialization" 25 26 using namespace mlir; 27 28 /// A pre-order depth-first visitor function for processing basic blocks. 29 /// 30 /// Visits the basic blocks starting from the given `headerBlock` in pre-order 31 /// depth-first manner and calls `blockHandler` on each block. Skips handling 32 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` 33 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s 34 /// successors. 35 /// 36 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order 37 /// of blocks in a function must satisfy the rule that blocks appear before 38 /// all blocks they dominate." This can be achieved by a pre-order CFG 39 /// traversal algorithm. To make the serialization output more logical and 40 /// readable to human, we perform depth-first CFG traversal and delay the 41 /// serialization of the merge block and the continue block, if exists, until 42 /// after all other blocks have been processed. 43 static LogicalResult 44 visitInPrettyBlockOrder(Block *headerBlock, 45 function_ref<LogicalResult(Block *)> blockHandler, 46 bool skipHeader = false, BlockRange skipBlocks = {}) { 47 llvm::df_iterator_default_set<Block *, 4> doneBlocks; 48 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); 49 50 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) { 51 if (skipHeader && block == headerBlock) 52 continue; 53 if (failed(blockHandler(block))) 54 return failure(); 55 } 56 return success(); 57 } 58 59 namespace mlir { 60 namespace spirv { 61 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { 62 if (auto resultID = 63 prepareConstant(op.getLoc(), op.getType(), op.getValue())) { 64 valueIDMap[op.getResult()] = resultID; 65 return success(); 66 } 67 return failure(); 68 } 69 70 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { 71 if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(), 72 /*isSpec=*/true)) { 73 // Emit the OpDecorate instruction for SpecId. 74 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) { 75 auto val = static_cast<uint32_t>(specID.getInt()); 76 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val}))) 77 return failure(); 78 } 79 80 specConstIDMap[op.getSymName()] = resultID; 81 return processName(resultID, op.getSymName()); 82 } 83 return failure(); 84 } 85 86 LogicalResult 87 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { 88 uint32_t typeID = 0; 89 if (failed(processType(op.getLoc(), op.getType(), typeID))) { 90 return failure(); 91 } 92 93 auto resultID = getNextID(); 94 95 SmallVector<uint32_t, 8> operands; 96 operands.push_back(typeID); 97 operands.push_back(resultID); 98 99 auto constituents = op.getConstituents(); 100 101 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { 102 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]); 103 104 auto constituentName = constituent.getValue(); 105 auto constituentID = getSpecConstID(constituentName); 106 107 if (!constituentID) { 108 return op.emitError("unknown result <id> for specialization constant ") 109 << constituentName; 110 } 111 112 operands.push_back(constituentID); 113 } 114 115 encodeInstructionInto(typesGlobalValues, 116 spirv::Opcode::OpSpecConstantComposite, operands); 117 specConstIDMap[op.getSymName()] = resultID; 118 119 return processName(resultID, op.getSymName()); 120 } 121 122 LogicalResult 123 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { 124 uint32_t typeID = 0; 125 if (failed(processType(op.getLoc(), op.getType(), typeID))) { 126 return failure(); 127 } 128 129 auto resultID = getNextID(); 130 131 SmallVector<uint32_t, 8> operands; 132 operands.push_back(typeID); 133 operands.push_back(resultID); 134 135 Block &block = op.getRegion().getBlocks().front(); 136 Operation &enclosedOp = block.getOperations().front(); 137 138 std::string enclosedOpName; 139 llvm::raw_string_ostream rss(enclosedOpName); 140 rss << "Op" << enclosedOp.getName().stripDialect(); 141 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); 142 143 if (!enclosedOpcode) { 144 op.emitError("Couldn't find op code for op ") 145 << enclosedOp.getName().getStringRef(); 146 return failure(); 147 } 148 149 operands.push_back(static_cast<uint32_t>(*enclosedOpcode)); 150 151 // Append operands to the enclosed op to the list of operands. 152 for (Value operand : enclosedOp.getOperands()) { 153 uint32_t id = getValueID(operand); 154 assert(id && "use before def!"); 155 operands.push_back(id); 156 } 157 158 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp, 159 operands); 160 valueIDMap[op.getResult()] = resultID; 161 162 return success(); 163 } 164 165 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { 166 auto undefType = op.getType(); 167 auto &id = undefValIDMap[undefType]; 168 if (!id) { 169 id = getNextID(); 170 uint32_t typeID = 0; 171 if (failed(processType(op.getLoc(), undefType, typeID))) 172 return failure(); 173 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, 174 {typeID, id}); 175 } 176 valueIDMap[op.getResult()] = id; 177 return success(); 178 } 179 180 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) { 181 for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { 182 uint32_t argTypeID = 0; 183 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { 184 return failure(); 185 } 186 auto argValueID = getNextID(); 187 188 // Process decoration attributes of arguments. 189 auto funcOp = cast<FunctionOpInterface>(*op); 190 for (auto argAttr : funcOp.getArgAttrs(idx)) { 191 if (argAttr.getName() != DecorationAttr::name) 192 continue; 193 194 if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) { 195 if (failed(processDecorationAttr(op->getLoc(), argValueID, 196 decAttr.getValue(), decAttr))) 197 return failure(); 198 } 199 } 200 201 valueIDMap[arg] = argValueID; 202 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, 203 {argTypeID, argValueID}); 204 } 205 return success(); 206 } 207 208 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { 209 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); 210 assert(functionHeader.empty() && functionBody.empty()); 211 212 uint32_t fnTypeID = 0; 213 // Generate type of the function. 214 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID))) 215 return failure(); 216 217 // Add the function definition. 218 SmallVector<uint32_t, 4> operands; 219 uint32_t resTypeID = 0; 220 auto resultTypes = op.getFunctionType().getResults(); 221 if (resultTypes.size() > 1) { 222 return op.emitError("cannot serialize function with multiple return types"); 223 } 224 if (failed(processType(op.getLoc(), 225 (resultTypes.empty() ? getVoidType() : resultTypes[0]), 226 resTypeID))) { 227 return failure(); 228 } 229 operands.push_back(resTypeID); 230 auto funcID = getOrCreateFunctionID(op.getName()); 231 operands.push_back(funcID); 232 operands.push_back(static_cast<uint32_t>(op.getFunctionControl())); 233 operands.push_back(fnTypeID); 234 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); 235 236 // Add function name. 237 if (failed(processName(funcID, op.getName()))) { 238 return failure(); 239 } 240 // Handle external functions with linkage_attributes(LinkageAttributes) 241 // differently. 242 auto linkageAttr = op.getLinkageAttributes(); 243 auto hasImportLinkage = 244 linkageAttr && (linkageAttr.value().getLinkageType().getValue() == 245 spirv::LinkageType::Import); 246 if (op.isExternal() && !hasImportLinkage) { 247 return op.emitError( 248 "'spirv.module' cannot contain external functions " 249 "without 'Import' linkage_attributes (LinkageAttributes)"); 250 } 251 if (op.isExternal() && hasImportLinkage) { 252 // Add an entry block to set up the block arguments 253 // to match the signature of the function. 254 // This is to generate OpFunctionParameter for functions with 255 // LinkageAttributes. 256 // WARNING: This operation has side-effect, it essentially adds a body 257 // to the func. Hence, making it not external anymore (isExternal() 258 // is going to return false for this function from now on) 259 // Hence, we'll remove the body once we are done with the serialization. 260 op.addEntryBlock(); 261 if (failed(processFuncParameter(op))) 262 return failure(); 263 // Don't need to process the added block, there is nothing to process, 264 // the fake body was added just to get the arguments, remove the body, 265 // since it's use is done. 266 op.eraseBody(); 267 } else { 268 if (failed(processFuncParameter(op))) 269 return failure(); 270 271 // Some instructions (e.g., OpVariable) in a function must be in the first 272 // block in the function. These instructions will be put in 273 // functionHeader. Thus, we put the label in functionHeader first, and 274 // omit it from the first block. OpLabel only needs to be added for 275 // functions with body (including empty body). Since, we added a fake body 276 // for functions with 'Import' Linkage attributes, these functions are 277 // essentially function delcaration, so they should not have OpLabel and a 278 // terminating instruction. That's why we skipped it for those functions. 279 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, 280 {getOrCreateBlockID(&op.front())}); 281 if (failed(processBlock(&op.front(), /*omitLabel=*/true))) 282 return failure(); 283 if (failed(visitInPrettyBlockOrder( 284 &op.front(), [&](Block *block) { return processBlock(block); }, 285 /*skipHeader=*/true))) { 286 return failure(); 287 } 288 289 // There might be OpPhi instructions who have value references needing to 290 // fix. 291 for (const auto &deferredValue : deferredPhiValues) { 292 Value value = deferredValue.first; 293 uint32_t id = getValueID(value); 294 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value 295 << " to id = " << id << '\n'); 296 assert(id && "OpPhi references undefined value!"); 297 for (size_t offset : deferredValue.second) 298 functionBody[offset] = id; 299 } 300 deferredPhiValues.clear(); 301 } 302 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() 303 << "' --\n"); 304 // Insert Decorations based on Function Attributes. 305 // Only attributes we should be considering for decoration are the 306 // ::mlir::spirv::Decoration attributes. 307 308 for (auto attr : op->getAttrs()) { 309 // Only generate OpDecorate op for spirv::Decoration attributes. 310 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>( 311 llvm::convertToCamelFromSnakeCase(attr.getName().strref(), 312 /*capitalizeFirst=*/true)); 313 if (isValidDecoration != std::nullopt) { 314 if (failed(processDecoration(op.getLoc(), funcID, attr))) { 315 return failure(); 316 } 317 } 318 } 319 // Insert OpFunctionEnd. 320 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {}); 321 322 functions.append(functionHeader.begin(), functionHeader.end()); 323 functions.append(functionBody.begin(), functionBody.end()); 324 functionHeader.clear(); 325 functionBody.clear(); 326 327 return success(); 328 } 329 330 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { 331 SmallVector<uint32_t, 4> operands; 332 SmallVector<StringRef, 2> elidedAttrs; 333 uint32_t resultID = 0; 334 uint32_t resultTypeID = 0; 335 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { 336 return failure(); 337 } 338 operands.push_back(resultTypeID); 339 resultID = getNextID(); 340 valueIDMap[op.getResult()] = resultID; 341 operands.push_back(resultID); 342 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>()); 343 if (attr) { 344 operands.push_back( 345 static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue())); 346 } 347 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); 348 for (auto arg : op.getODSOperands(0)) { 349 auto argID = getValueID(arg); 350 if (!argID) { 351 return emitError(op.getLoc(), "operand 0 has a use before def"); 352 } 353 operands.push_back(argID); 354 } 355 if (failed(emitDebugLine(functionHeader, op.getLoc()))) 356 return failure(); 357 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); 358 for (auto attr : op->getAttrs()) { 359 if (llvm::any_of(elidedAttrs, [&](StringRef elided) { 360 return attr.getName() == elided; 361 })) { 362 continue; 363 } 364 if (failed(processDecoration(op.getLoc(), resultID, attr))) { 365 return failure(); 366 } 367 } 368 return success(); 369 } 370 371 LogicalResult 372 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { 373 // Get TypeID. 374 uint32_t resultTypeID = 0; 375 SmallVector<StringRef, 4> elidedAttrs; 376 if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) { 377 return failure(); 378 } 379 380 elidedAttrs.push_back("type"); 381 SmallVector<uint32_t, 4> operands; 382 operands.push_back(resultTypeID); 383 auto resultID = getNextID(); 384 385 // Encode the name. 386 auto varName = varOp.getSymName(); 387 elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); 388 if (failed(processName(resultID, varName))) { 389 return failure(); 390 } 391 globalVarIDMap[varName] = resultID; 392 operands.push_back(resultID); 393 394 // Encode StorageClass. 395 operands.push_back(static_cast<uint32_t>(varOp.storageClass())); 396 397 // Encode initialization. 398 StringRef initAttrName = varOp.getInitializerAttrName().getValue(); 399 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) { 400 uint32_t initializerID = 0; 401 auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName); 402 Operation *initOp = SymbolTable::lookupNearestSymbolFrom( 403 varOp->getParentOp(), initRef.getAttr()); 404 405 // Check if initializer is GlobalVariable or SpecConstant* cases. 406 if (isa<spirv::GlobalVariableOp>(initOp)) 407 initializerID = getVariableID(*initSymbolName); 408 else 409 initializerID = getSpecConstID(*initSymbolName); 410 411 if (!initializerID) 412 return emitError(varOp.getLoc(), 413 "invalid usage of undefined variable as initializer"); 414 415 operands.push_back(initializerID); 416 elidedAttrs.push_back(initAttrName); 417 } 418 419 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc()))) 420 return failure(); 421 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands); 422 elidedAttrs.push_back(initAttrName); 423 424 // Encode decorations. 425 for (auto attr : varOp->getAttrs()) { 426 if (llvm::any_of(elidedAttrs, [&](StringRef elided) { 427 return attr.getName() == elided; 428 })) { 429 continue; 430 } 431 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { 432 return failure(); 433 } 434 } 435 return success(); 436 } 437 438 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { 439 // Assign <id>s to all blocks so that branches inside the SelectionOp can 440 // resolve properly. 441 auto &body = selectionOp.getBody(); 442 for (Block &block : body) 443 getOrCreateBlockID(&block); 444 445 auto *headerBlock = selectionOp.getHeaderBlock(); 446 auto *mergeBlock = selectionOp.getMergeBlock(); 447 auto headerID = getBlockID(headerBlock); 448 auto mergeID = getBlockID(mergeBlock); 449 auto loc = selectionOp.getLoc(); 450 451 // This SelectionOp is in some MLIR block with preceding and following ops. In 452 // the binary format, it should reside in separate SPIR-V blocks from its 453 // preceding and following ops. So we need to emit unconditional branches to 454 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal 455 // flow afterwards. 456 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); 457 458 // Emit the selection header block, which dominates all other blocks, first. 459 // We need to emit an OpSelectionMerge instruction before the selection header 460 // block's terminator. 461 auto emitSelectionMerge = [&]() { 462 if (failed(emitDebugLine(functionBody, loc))) 463 return failure(); 464 lastProcessedWasMergeInst = true; 465 encodeInstructionInto( 466 functionBody, spirv::Opcode::OpSelectionMerge, 467 {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())}); 468 return success(); 469 }; 470 if (failed( 471 processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge))) 472 return failure(); 473 474 // Process all blocks with a depth-first visitor starting from the header 475 // block. The selection header block and merge block are skipped by this 476 // visitor. 477 if (failed(visitInPrettyBlockOrder( 478 headerBlock, [&](Block *block) { return processBlock(block); }, 479 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) 480 return failure(); 481 482 // There is nothing to do for the merge block in the selection, which just 483 // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel 484 // instruction to start a new SPIR-V block for ops following this SelectionOp. 485 // The block should use the <id> for the merge block. 486 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 487 LLVM_DEBUG(llvm::dbgs() << "done merge "); 488 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); 489 LLVM_DEBUG(llvm::dbgs() << "\n"); 490 return success(); 491 } 492 493 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { 494 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve 495 // properly. We don't need to assign for the entry block, which is just for 496 // satisfying MLIR region's structural requirement. 497 auto &body = loopOp.getBody(); 498 for (Block &block : llvm::drop_begin(body)) 499 getOrCreateBlockID(&block); 500 501 auto *headerBlock = loopOp.getHeaderBlock(); 502 auto *continueBlock = loopOp.getContinueBlock(); 503 auto *mergeBlock = loopOp.getMergeBlock(); 504 auto headerID = getBlockID(headerBlock); 505 auto continueID = getBlockID(continueBlock); 506 auto mergeID = getBlockID(mergeBlock); 507 auto loc = loopOp.getLoc(); 508 509 // This LoopOp is in some MLIR block with preceding and following ops. In the 510 // binary format, it should reside in separate SPIR-V blocks from its 511 // preceding and following ops. So we need to emit unconditional branches to 512 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow 513 // afterwards. 514 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); 515 516 // LoopOp's entry block is just there for satisfying MLIR's structural 517 // requirements so we omit it and start serialization from the loop header 518 // block. 519 520 // Emit the loop header block, which dominates all other blocks, first. We 521 // need to emit an OpLoopMerge instruction before the loop header block's 522 // terminator. 523 auto emitLoopMerge = [&]() { 524 if (failed(emitDebugLine(functionBody, loc))) 525 return failure(); 526 lastProcessedWasMergeInst = true; 527 encodeInstructionInto( 528 functionBody, spirv::Opcode::OpLoopMerge, 529 {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())}); 530 return success(); 531 }; 532 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) 533 return failure(); 534 535 // Process all blocks with a depth-first visitor starting from the header 536 // block. The loop header block, loop continue block, and loop merge block are 537 // skipped by this visitor and handled later in this function. 538 if (failed(visitInPrettyBlockOrder( 539 headerBlock, [&](Block *block) { return processBlock(block); }, 540 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) 541 return failure(); 542 543 // We have handled all other blocks. Now get to the loop continue block. 544 if (failed(processBlock(continueBlock))) 545 return failure(); 546 547 // There is nothing to do for the merge block in the loop, which just contains 548 // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction 549 // to start a new SPIR-V block for ops following this LoopOp. The block should 550 // use the <id> for the merge block. 551 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 552 LLVM_DEBUG(llvm::dbgs() << "done merge "); 553 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); 554 LLVM_DEBUG(llvm::dbgs() << "\n"); 555 return success(); 556 } 557 558 LogicalResult Serializer::processBranchConditionalOp( 559 spirv::BranchConditionalOp condBranchOp) { 560 auto conditionID = getValueID(condBranchOp.getCondition()); 561 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); 562 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); 563 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; 564 565 if (auto weights = condBranchOp.getBranchWeights()) { 566 for (auto val : weights->getValue()) 567 arguments.push_back(cast<IntegerAttr>(val).getInt()); 568 } 569 570 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc()))) 571 return failure(); 572 encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, 573 arguments); 574 return success(); 575 } 576 577 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { 578 if (failed(emitDebugLine(functionBody, branchOp.getLoc()))) 579 return failure(); 580 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, 581 {getOrCreateBlockID(branchOp.getTarget())}); 582 return success(); 583 } 584 585 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { 586 auto varName = addressOfOp.getVariable(); 587 auto variableID = getVariableID(varName); 588 if (!variableID) { 589 return addressOfOp.emitError("unknown result <id> for variable ") 590 << varName; 591 } 592 valueIDMap[addressOfOp.getPointer()] = variableID; 593 return success(); 594 } 595 596 LogicalResult 597 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { 598 auto constName = referenceOfOp.getSpecConst(); 599 auto constID = getSpecConstID(constName); 600 if (!constID) { 601 return referenceOfOp.emitError( 602 "unknown result <id> for specialization constant ") 603 << constName; 604 } 605 valueIDMap[referenceOfOp.getReference()] = constID; 606 return success(); 607 } 608 609 template <> 610 LogicalResult 611 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { 612 SmallVector<uint32_t, 4> operands; 613 // Add the ExecutionModel. 614 operands.push_back(static_cast<uint32_t>(op.getExecutionModel())); 615 // Add the function <id>. 616 auto funcID = getFunctionID(op.getFn()); 617 if (!funcID) { 618 return op.emitError("missing <id> for function ") 619 << op.getFn() 620 << "; function needs to be defined before spirv.EntryPoint is " 621 "serialized"; 622 } 623 operands.push_back(funcID); 624 // Add the name of the function. 625 spirv::encodeStringLiteralInto(operands, op.getFn()); 626 627 // Add the interface values. 628 if (auto interface = op.getInterface()) { 629 for (auto var : interface.getValue()) { 630 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue()); 631 if (!id) { 632 return op.emitError( 633 "referencing undefined global variable." 634 "spirv.EntryPoint is at the end of spirv.module. All " 635 "referenced variables should already be defined"); 636 } 637 operands.push_back(id); 638 } 639 } 640 encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); 641 return success(); 642 } 643 644 template <> 645 LogicalResult 646 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { 647 SmallVector<uint32_t, 4> operands; 648 // Add the function <id>. 649 auto funcID = getFunctionID(op.getFn()); 650 if (!funcID) { 651 return op.emitError("missing <id> for function ") 652 << op.getFn() 653 << "; function needs to be serialized before ExecutionModeOp is " 654 "serialized"; 655 } 656 operands.push_back(funcID); 657 // Add the ExecutionMode. 658 operands.push_back(static_cast<uint32_t>(op.getExecutionMode())); 659 660 // Serialize values if any. 661 auto values = op.getValues(); 662 if (values) { 663 for (auto &intVal : values.getValue()) { 664 operands.push_back(static_cast<uint32_t>( 665 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue())); 666 } 667 } 668 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, 669 operands); 670 return success(); 671 } 672 673 template <> 674 LogicalResult 675 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { 676 auto funcName = op.getCallee(); 677 uint32_t resTypeID = 0; 678 679 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); 680 if (failed(processType(op.getLoc(), resultTy, resTypeID))) 681 return failure(); 682 683 auto funcID = getOrCreateFunctionID(funcName); 684 auto funcCallID = getNextID(); 685 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; 686 687 for (auto value : op.getArguments()) { 688 auto valueID = getValueID(value); 689 assert(valueID && "cannot find a value for spirv.FunctionCall"); 690 operands.push_back(valueID); 691 } 692 693 if (!isa<NoneType>(resultTy)) 694 valueIDMap[op.getResult(0)] = funcCallID; 695 696 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); 697 return success(); 698 } 699 700 template <> 701 LogicalResult 702 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { 703 SmallVector<uint32_t, 4> operands; 704 SmallVector<StringRef, 2> elidedAttrs; 705 706 for (Value operand : op->getOperands()) { 707 auto id = getValueID(operand); 708 assert(id && "use before def!"); 709 operands.push_back(id); 710 } 711 712 StringAttr memoryAccess = op.getMemoryAccessAttrName(); 713 if (auto attr = op->getAttr(memoryAccess)) { 714 operands.push_back( 715 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); 716 } 717 718 elidedAttrs.push_back(memoryAccess.strref()); 719 720 StringAttr alignment = op.getAlignmentAttrName(); 721 if (auto attr = op->getAttr(alignment)) { 722 operands.push_back(static_cast<uint32_t>( 723 cast<IntegerAttr>(attr).getValue().getZExtValue())); 724 } 725 726 elidedAttrs.push_back(alignment.strref()); 727 728 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName(); 729 if (auto attr = op->getAttr(sourceMemoryAccess)) { 730 operands.push_back( 731 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); 732 } 733 734 elidedAttrs.push_back(sourceMemoryAccess.strref()); 735 736 StringAttr sourceAlignment = op.getSourceAlignmentAttrName(); 737 if (auto attr = op->getAttr(sourceAlignment)) { 738 operands.push_back(static_cast<uint32_t>( 739 cast<IntegerAttr>(attr).getValue().getZExtValue())); 740 } 741 742 elidedAttrs.push_back(sourceAlignment.strref()); 743 if (failed(emitDebugLine(functionBody, op.getLoc()))) 744 return failure(); 745 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); 746 747 return success(); 748 } 749 template <> 750 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>( 751 spirv::GenericCastToPtrExplicitOp op) { 752 SmallVector<uint32_t, 4> operands; 753 Type resultTy; 754 Location loc = op->getLoc(); 755 uint32_t resultTypeID = 0; 756 uint32_t resultID = 0; 757 resultTy = op->getResult(0).getType(); 758 if (failed(processType(loc, resultTy, resultTypeID))) 759 return failure(); 760 operands.push_back(resultTypeID); 761 762 resultID = getNextID(); 763 operands.push_back(resultID); 764 valueIDMap[op->getResult(0)] = resultID; 765 766 for (Value operand : op->getOperands()) 767 operands.push_back(getValueID(operand)); 768 spirv::StorageClass resultStorage = 769 cast<spirv::PointerType>(resultTy).getStorageClass(); 770 operands.push_back(static_cast<uint32_t>(resultStorage)); 771 encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit, 772 operands); 773 return success(); 774 } 775 776 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and 777 // various Serializer::processOp<...>() specializations. 778 #define GET_SERIALIZATION_FNS 779 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 780 781 } // namespace spirv 782 } // namespace mlir 783