1 //===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // SPIRVSerializationGen generates common utility functions for SPIR-V 10 // serialization. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/CodeGenHelpers.h" 16 #include "mlir/TableGen/Format.h" 17 #include "mlir/TableGen/GenInfo.h" 18 #include "mlir/TableGen/Operator.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringMap.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/ADT/StringSet.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include "llvm/TableGen/Error.h" 29 #include "llvm/TableGen/Record.h" 30 #include "llvm/TableGen/TableGenBackend.h" 31 32 #include <list> 33 #include <optional> 34 35 using llvm::ArrayRef; 36 using llvm::cast; 37 using llvm::formatv; 38 using llvm::isa; 39 using llvm::raw_ostream; 40 using llvm::raw_string_ostream; 41 using llvm::Record; 42 using llvm::RecordKeeper; 43 using llvm::SmallVector; 44 using llvm::SMLoc; 45 using llvm::StringMap; 46 using llvm::StringRef; 47 using mlir::tblgen::Attribute; 48 using mlir::tblgen::EnumAttr; 49 using mlir::tblgen::EnumAttrCase; 50 using mlir::tblgen::NamedAttribute; 51 using mlir::tblgen::NamedTypeConstraint; 52 using mlir::tblgen::NamespaceEmitter; 53 using mlir::tblgen::Operator; 54 55 //===----------------------------------------------------------------------===// 56 // Availability Wrapper Class 57 //===----------------------------------------------------------------------===// 58 59 namespace { 60 // Wrapper class with helper methods for accessing availability defined in 61 // TableGen. 62 class Availability { 63 public: 64 explicit Availability(const Record *def); 65 66 // Returns the name of the direct TableGen class for this availability 67 // instance. 68 StringRef getClass() const; 69 70 // Returns the generated C++ interface's class namespace. 71 StringRef getInterfaceClassNamespace() const; 72 73 // Returns the generated C++ interface's class name. 74 StringRef getInterfaceClassName() const; 75 76 // Returns the generated C++ interface's description. 77 StringRef getInterfaceDescription() const; 78 79 // Returns the name of the query function insided the generated C++ interface. 80 StringRef getQueryFnName() const; 81 82 // Returns the return type of the query function insided the generated C++ 83 // interface. 84 StringRef getQueryFnRetType() const; 85 86 // Returns the code for merging availability requirements. 87 StringRef getMergeActionCode() const; 88 89 // Returns the initializer expression for initializing the final availability 90 // requirements. 91 StringRef getMergeInitializer() const; 92 93 // Returns the C++ type for an availability instance. 94 StringRef getMergeInstanceType() const; 95 96 // Returns the C++ statements for preparing availability instance. 97 StringRef getMergeInstancePreparation() const; 98 99 // Returns the concrete availability instance carried in this case. 100 StringRef getMergeInstance() const; 101 102 // Returns the underlying LLVM TableGen Record. 103 const Record *getDef() const { return def; } 104 105 private: 106 // The TableGen definition of this availability. 107 const Record *def; 108 }; 109 } // namespace 110 111 Availability::Availability(const Record *def) : def(def) { 112 assert(def->isSubClassOf("Availability") && 113 "must be subclass of TableGen 'Availability' class"); 114 } 115 116 StringRef Availability::getClass() const { 117 SmallVector<const Record *, 1> parentClass; 118 def->getDirectSuperClasses(parentClass); 119 if (parentClass.size() != 1) { 120 PrintFatalError(def->getLoc(), 121 "expected to only have one direct superclass"); 122 } 123 return parentClass.front()->getName(); 124 } 125 126 StringRef Availability::getInterfaceClassNamespace() const { 127 return def->getValueAsString("cppNamespace"); 128 } 129 130 StringRef Availability::getInterfaceClassName() const { 131 return def->getValueAsString("interfaceName"); 132 } 133 134 StringRef Availability::getInterfaceDescription() const { 135 return def->getValueAsString("interfaceDescription"); 136 } 137 138 StringRef Availability::getQueryFnRetType() const { 139 return def->getValueAsString("queryFnRetType"); 140 } 141 142 StringRef Availability::getQueryFnName() const { 143 return def->getValueAsString("queryFnName"); 144 } 145 146 StringRef Availability::getMergeActionCode() const { 147 return def->getValueAsString("mergeAction"); 148 } 149 150 StringRef Availability::getMergeInitializer() const { 151 return def->getValueAsString("initializer"); 152 } 153 154 StringRef Availability::getMergeInstanceType() const { 155 return def->getValueAsString("instanceType"); 156 } 157 158 StringRef Availability::getMergeInstancePreparation() const { 159 return def->getValueAsString("instancePreparation"); 160 } 161 162 StringRef Availability::getMergeInstance() const { 163 return def->getValueAsString("instance"); 164 } 165 166 // Returns the availability spec of the given `def`. 167 std::vector<Availability> getAvailabilities(const Record &def) { 168 std::vector<Availability> availabilities; 169 170 if (def.getValue("availability")) { 171 std::vector<const Record *> availDefs = 172 def.getValueAsListOfDefs("availability"); 173 availabilities.reserve(availDefs.size()); 174 for (const Record *avail : availDefs) 175 availabilities.emplace_back(avail); 176 } 177 178 return availabilities; 179 } 180 181 //===----------------------------------------------------------------------===// 182 // Availability Interface Definitions AutoGen 183 //===----------------------------------------------------------------------===// 184 185 static void emitInterfaceDef(const Availability &availability, 186 raw_ostream &os) { 187 188 os << availability.getQueryFnRetType() << " "; 189 190 StringRef cppNamespace = availability.getInterfaceClassNamespace(); 191 cppNamespace.consume_front("::"); 192 if (!cppNamespace.empty()) 193 os << cppNamespace << "::"; 194 195 StringRef methodName = availability.getQueryFnName(); 196 os << availability.getInterfaceClassName() << "::" << methodName << "() {\n" 197 << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n" 198 << "}\n"; 199 } 200 201 static bool emitInterfaceDefs(const RecordKeeper &records, raw_ostream &os) { 202 llvm::emitSourceFileHeader("Availability Interface Definitions", os, records); 203 204 auto defs = records.getAllDerivedDefinitions("Availability"); 205 SmallVector<const Record *, 1> handledClasses; 206 for (const Record *def : defs) { 207 SmallVector<const Record *, 1> parent; 208 def->getDirectSuperClasses(parent); 209 if (parent.size() != 1) { 210 PrintFatalError(def->getLoc(), 211 "expected to only have one direct superclass"); 212 } 213 if (llvm::is_contained(handledClasses, parent.front())) 214 continue; 215 216 Availability availability(def); 217 emitInterfaceDef(availability, os); 218 handledClasses.push_back(parent.front()); 219 } 220 return false; 221 } 222 223 //===----------------------------------------------------------------------===// 224 // Availability Interface Declarations AutoGen 225 //===----------------------------------------------------------------------===// 226 227 static void emitConceptDecl(const Availability &availability, raw_ostream &os) { 228 os << " class Concept {\n" 229 << " public:\n" 230 << " virtual ~Concept() = default;\n" 231 << " virtual " << availability.getQueryFnRetType() << " " 232 << availability.getQueryFnName() 233 << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n" 234 << " };\n"; 235 } 236 237 static void emitModelDecl(const Availability &availability, raw_ostream &os) { 238 for (const char *modelClass : {"Model", "FallbackModel"}) { 239 os << " template<typename ConcreteOp>\n"; 240 os << " class " << modelClass << " : public Concept {\n" 241 << " public:\n" 242 << " using Interface = " << availability.getInterfaceClassName() 243 << ";\n" 244 << " " << availability.getQueryFnRetType() << " " 245 << availability.getQueryFnName() 246 << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n" 247 << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n" 248 << " (void)op;\n" 249 // Forward to the method on the concrete operation type. 250 << " return op." << availability.getQueryFnName() << "();\n" 251 << " }\n" 252 << " };\n"; 253 } 254 os << " template<typename ConcreteModel, typename ConcreteOp>\n"; 255 os << " class ExternalModel : public FallbackModel<ConcreteOp> {};\n"; 256 } 257 258 static void emitInterfaceDecl(const Availability &availability, 259 raw_ostream &os) { 260 StringRef interfaceName = availability.getInterfaceClassName(); 261 std::string interfaceTraitsName = 262 std::string(formatv("{0}Traits", interfaceName)); 263 264 StringRef cppNamespace = availability.getInterfaceClassNamespace(); 265 NamespaceEmitter nsEmitter(os, cppNamespace); 266 os << "class " << interfaceName << ";\n\n"; 267 268 // Emit the traits struct containing the concept and model declarations. 269 os << "namespace detail {\n" 270 << "struct " << interfaceTraitsName << " {\n"; 271 emitConceptDecl(availability, os); 272 os << '\n'; 273 emitModelDecl(availability, os); 274 os << "};\n} // namespace detail\n\n"; 275 276 // Emit the main interface class declaration. 277 os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n"; 278 os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n" 279 "public:\n" 280 " using OpInterface<{1}, detail::{2}>::OpInterface;\n", 281 interfaceName, interfaceName, interfaceTraitsName); 282 283 // Emit query function declaration. 284 os << " " << availability.getQueryFnRetType() << " " 285 << availability.getQueryFnName() << "();\n"; 286 os << "};\n\n"; 287 } 288 289 static bool emitInterfaceDecls(const RecordKeeper &records, raw_ostream &os) { 290 llvm::emitSourceFileHeader("Availability Interface Declarations", os, 291 records); 292 293 auto defs = records.getAllDerivedDefinitions("Availability"); 294 SmallVector<const Record *, 4> handledClasses; 295 for (const Record *def : defs) { 296 SmallVector<const Record *, 1> parent; 297 def->getDirectSuperClasses(parent); 298 if (parent.size() != 1) { 299 PrintFatalError(def->getLoc(), 300 "expected to only have one direct superclass"); 301 } 302 if (llvm::is_contained(handledClasses, parent.front())) 303 continue; 304 305 Availability avail(def); 306 emitInterfaceDecl(avail, os); 307 handledClasses.push_back(parent.front()); 308 } 309 return false; 310 } 311 312 //===----------------------------------------------------------------------===// 313 // Availability Interface Hook Registration 314 //===----------------------------------------------------------------------===// 315 316 // Registers the operation interface generator to mlir-tblgen. 317 static mlir::GenRegistration 318 genInterfaceDecls("gen-avail-interface-decls", 319 "Generate availability interface declarations", 320 [](const RecordKeeper &records, raw_ostream &os) { 321 return emitInterfaceDecls(records, os); 322 }); 323 324 // Registers the operation interface generator to mlir-tblgen. 325 static mlir::GenRegistration 326 genInterfaceDefs("gen-avail-interface-defs", 327 "Generate op interface definitions", 328 [](const RecordKeeper &records, raw_ostream &os) { 329 return emitInterfaceDefs(records, os); 330 }); 331 332 //===----------------------------------------------------------------------===// 333 // Enum Availability Query AutoGen 334 //===----------------------------------------------------------------------===// 335 336 static void emitAvailabilityQueryForIntEnum(const Record &enumDef, 337 raw_ostream &os) { 338 EnumAttr enumAttr(enumDef); 339 StringRef enumName = enumAttr.getEnumClassName(); 340 std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases(); 341 342 // Mapping from availability class name to (enumerant, availability 343 // specification) pairs. 344 llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>> 345 classCaseMap; 346 347 // Place all availability specifications to their corresponding 348 // availability classes. 349 for (const EnumAttrCase &enumerant : enumerants) 350 for (const Availability &avail : getAvailabilities(enumerant.getDef())) 351 classCaseMap[avail.getClass()].push_back({enumerant, avail}); 352 353 for (const auto &classCasePair : classCaseMap) { 354 Availability avail = classCasePair.getValue().front().second; 355 356 os << formatv("std::optional<{0}> {1}({2} value) {{\n", 357 avail.getMergeInstanceType(), avail.getQueryFnName(), 358 enumName); 359 360 os << " switch (value) {\n"; 361 for (const auto &caseSpecPair : classCasePair.getValue()) { 362 EnumAttrCase enumerant = caseSpecPair.first; 363 Availability avail = caseSpecPair.second; 364 os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, 365 enumerant.getSymbol(), avail.getMergeInstancePreparation(), 366 avail.getMergeInstanceType(), avail.getMergeInstance()); 367 } 368 // Only emit default if uncovered cases. 369 if (classCasePair.getValue().size() < enumAttr.getAllCases().size()) 370 os << " default: break;\n"; 371 os << " }\n" 372 << " return std::nullopt;\n" 373 << "}\n"; 374 } 375 } 376 377 static void emitAvailabilityQueryForBitEnum(const Record &enumDef, 378 raw_ostream &os) { 379 EnumAttr enumAttr(enumDef); 380 StringRef enumName = enumAttr.getEnumClassName(); 381 std::string underlyingType = std::string(enumAttr.getUnderlyingType()); 382 std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases(); 383 384 // Mapping from availability class name to (enumerant, availability 385 // specification) pairs. 386 llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>> 387 classCaseMap; 388 389 // Place all availability specifications to their corresponding 390 // availability classes. 391 for (const EnumAttrCase &enumerant : enumerants) 392 for (const Availability &avail : getAvailabilities(enumerant.getDef())) 393 classCaseMap[avail.getClass()].push_back({enumerant, avail}); 394 395 for (const auto &classCasePair : classCaseMap) { 396 Availability avail = classCasePair.getValue().front().second; 397 398 os << formatv("std::optional<{0}> {1}({2} value) {{\n", 399 avail.getMergeInstanceType(), avail.getQueryFnName(), 400 enumName); 401 402 os << formatv( 403 " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1" 404 " && \"cannot have more than one bit set\");\n", 405 underlyingType); 406 407 os << " switch (value) {\n"; 408 for (const auto &caseSpecPair : classCasePair.getValue()) { 409 EnumAttrCase enumerant = caseSpecPair.first; 410 Availability avail = caseSpecPair.second; 411 os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, 412 enumerant.getSymbol(), avail.getMergeInstancePreparation(), 413 avail.getMergeInstanceType(), avail.getMergeInstance()); 414 } 415 os << " default: break;\n"; 416 os << " }\n" 417 << " return std::nullopt;\n" 418 << "}\n"; 419 } 420 } 421 422 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { 423 EnumAttr enumAttr(enumDef); 424 StringRef enumName = enumAttr.getEnumClassName(); 425 StringRef cppNamespace = enumAttr.getCppNamespace(); 426 auto enumerants = enumAttr.getAllCases(); 427 428 llvm::SmallVector<StringRef, 2> namespaces; 429 llvm::SplitString(cppNamespace, namespaces, "::"); 430 431 for (auto ns : namespaces) 432 os << "namespace " << ns << " {\n"; 433 434 llvm::StringSet<> handledClasses; 435 436 // Place all availability specifications to their corresponding 437 // availability classes. 438 for (const EnumAttrCase &enumerant : enumerants) 439 for (const Availability &avail : getAvailabilities(enumerant.getDef())) { 440 StringRef className = avail.getClass(); 441 if (handledClasses.count(className)) 442 continue; 443 os << formatv("std::optional<{0}> {1}({2} value);\n", 444 avail.getMergeInstanceType(), avail.getQueryFnName(), 445 enumName); 446 handledClasses.insert(className); 447 } 448 449 for (auto ns : llvm::reverse(namespaces)) 450 os << "} // namespace " << ns << "\n"; 451 } 452 453 static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { 454 llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os, 455 records); 456 457 auto defs = records.getAllDerivedDefinitions("EnumAttrInfo"); 458 for (const auto *def : defs) 459 emitEnumDecl(*def, os); 460 461 return false; 462 } 463 464 static void emitEnumDef(const Record &enumDef, raw_ostream &os) { 465 EnumAttr enumAttr(enumDef); 466 StringRef cppNamespace = enumAttr.getCppNamespace(); 467 468 llvm::SmallVector<StringRef, 2> namespaces; 469 llvm::SplitString(cppNamespace, namespaces, "::"); 470 471 for (auto ns : namespaces) 472 os << "namespace " << ns << " {\n"; 473 474 if (enumAttr.isBitEnum()) { 475 emitAvailabilityQueryForBitEnum(enumDef, os); 476 } else { 477 emitAvailabilityQueryForIntEnum(enumDef, os); 478 } 479 480 for (auto ns : llvm::reverse(namespaces)) 481 os << "} // namespace " << ns << "\n"; 482 os << "\n"; 483 } 484 485 static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { 486 llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os, 487 records); 488 489 auto defs = records.getAllDerivedDefinitions("EnumAttrInfo"); 490 for (const auto *def : defs) 491 emitEnumDef(*def, os); 492 493 return false; 494 } 495 496 //===----------------------------------------------------------------------===// 497 // Enum Availability Query Hook Registration 498 //===----------------------------------------------------------------------===// 499 500 // Registers the enum utility generator to mlir-tblgen. 501 static mlir::GenRegistration 502 genEnumDecls("gen-spirv-enum-avail-decls", 503 "Generate SPIR-V enum availability declarations", 504 [](const RecordKeeper &records, raw_ostream &os) { 505 return emitEnumDecls(records, os); 506 }); 507 508 // Registers the enum utility generator to mlir-tblgen. 509 static mlir::GenRegistration 510 genEnumDefs("gen-spirv-enum-avail-defs", 511 "Generate SPIR-V enum availability definitions", 512 [](const RecordKeeper &records, raw_ostream &os) { 513 return emitEnumDefs(records, os); 514 }); 515 516 //===----------------------------------------------------------------------===// 517 // Serialization AutoGen 518 //===----------------------------------------------------------------------===// 519 520 // These enums are encoded as <id> to constant values in SPIR-V blob, but we 521 // directly use the constant value as attribute in SPIR-V dialect. So need 522 // to handle them separately from normal enum attributes. 523 constexpr llvm::StringLiteral constantIdEnumAttrs[] = { 524 "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr", 525 "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr", 526 "SPIRV_MatrixLayoutAttr"}; 527 528 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The 529 /// generates code extracts the attribute with name `attrName` from 530 /// `operandList` of `op`. 531 static void emitAttributeSerialization(const Attribute &attr, 532 ArrayRef<SMLoc> loc, StringRef tabs, 533 StringRef opVar, StringRef operandList, 534 StringRef attrName, raw_ostream &os) { 535 os << tabs 536 << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName); 537 if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) { 538 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); 539 os << tabs 540 << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), " 541 "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>(" 542 "::llvm::cast<{2}::{3}Attr>(attr).getValue()))));\n", 543 operandList, opVar, baseEnum.getCppNamespace(), 544 baseEnum.getEnumClassName()); 545 } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") || 546 attr.isSubClassOf("SPIRV_I32EnumAttr")) { 547 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); 548 os << tabs 549 << formatv(" {0}.push_back(static_cast<uint32_t>(" 550 "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n", 551 operandList, baseEnum.getCppNamespace(), 552 baseEnum.getEnumClassName()); 553 } else if (attr.getAttrDefName() == "I32ArrayAttr") { 554 // Serialize all the elements of the array 555 os << tabs << " for (auto attrElem : llvm::cast<ArrayAttr>(attr)) {\n"; 556 os << tabs 557 << formatv(" {0}.push_back(static_cast<uint32_t>(" 558 "llvm::cast<IntegerAttr>(attrElem).getValue().getZExtValue())" 559 ");\n", 560 operandList); 561 os << tabs << " }\n"; 562 } else if (attr.getAttrDefName() == "I32Attr") { 563 os << tabs 564 << formatv( 565 " {0}.push_back(static_cast<uint32_t>(" 566 "llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n", 567 operandList); 568 } else if (attr.isEnumAttr() || attr.isTypeAttr()) { 569 // It may be the first time this type appears in the IR, so we need to 570 // process it. 571 StringRef attrTypeID = "attrTypeID"; 572 os << tabs << formatv(" uint32_t {0} = 0;\n", attrTypeID); 573 os << tabs 574 << formatv(" if (failed(processType({0}.getLoc(), " 575 "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n", 576 opVar, attrTypeID); 577 os << tabs << " return failure();\n"; 578 os << tabs << " }\n"; 579 os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList); 580 } else { 581 PrintFatalError( 582 loc, 583 llvm::Twine( 584 "unhandled attribute type in SPIR-V serialization generation : '") + 585 attr.getAttrDefName() + llvm::Twine("'")); 586 } 587 os << tabs << "}\n"; 588 } 589 590 /// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The 591 /// generated queries the SSA-ID if operand is a SSA-Value, or serializes the 592 /// attributes. The `operands` vector is updated appropriately. `elidedAttrs` 593 /// updated as well to include the serialized attributes. 594 static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc, 595 StringRef tabs, StringRef opVar, 596 StringRef operands, StringRef elidedAttrs, 597 raw_ostream &os) { 598 using mlir::tblgen::Argument; 599 600 // SPIR-V ops can mix operands and attributes in the definition. These 601 // operands and attributes are serialized in the exact order of the definition 602 // to match SPIR-V binary format requirements. It can cause excessive 603 // generated code bloat because we are emitting code to handle each 604 // operand/attribute separately. So here we probe first to check whether all 605 // the operands are ahead of attributes. Then we can serialize all operands 606 // together. 607 608 // Whether all operands are ahead of all attributes in the op's spec. 609 bool areOperandsAheadOfAttrs = true; 610 // Find the first attribute. 611 const Argument *it = llvm::find_if(op.getArgs(), [](const Argument &arg) { 612 return isa<NamedAttribute *>(arg); 613 }); 614 // Check whether all following arguments are attributes. 615 for (const Argument *ie = op.arg_end(); it != ie; ++it) { 616 if (!isa<NamedAttribute *>(*it)) { 617 areOperandsAheadOfAttrs = false; 618 break; 619 } 620 } 621 622 // Serialize all operands together. 623 if (areOperandsAheadOfAttrs) { 624 if (op.getNumOperands() != 0) { 625 os << tabs 626 << formatv("for (Value operand : {0}->getOperands()) {{\n", opVar); 627 os << tabs << " auto id = getValueID(operand);\n"; 628 os << tabs << " assert(id && \"use before def!\");\n"; 629 os << tabs << formatv(" {0}.push_back(id);\n", operands); 630 os << tabs << "}\n"; 631 } 632 for (const NamedAttribute &attr : op.getAttributes()) { 633 emitAttributeSerialization( 634 (attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc, 635 tabs, opVar, operands, attr.name, os); 636 os << tabs 637 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr.name); 638 } 639 return; 640 } 641 642 // Serialize operands separately. 643 auto operandNum = 0; 644 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { 645 auto argument = op.getArg(i); 646 os << tabs << "{\n"; 647 if (isa<NamedTypeConstraint *>(argument)) { 648 os << tabs 649 << formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar, 650 operandNum); 651 os << tabs << " auto argID = getValueID(arg);\n"; 652 os << tabs << " if (!argID) {\n"; 653 os << tabs 654 << formatv(" return emitError({0}.getLoc(), " 655 "\"operand #{1} has a use before def\");\n", 656 opVar, operandNum); 657 os << tabs << " }\n"; 658 os << tabs << formatv(" {0}.push_back(argID);\n", operands); 659 os << " }\n"; 660 operandNum++; 661 } else { 662 NamedAttribute *attr = cast<NamedAttribute *>(argument); 663 auto newtabs = tabs.str() + " "; 664 emitAttributeSerialization( 665 (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), 666 loc, newtabs, opVar, operands, attr->name, os); 667 os << newtabs 668 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name); 669 } 670 os << tabs << "}\n"; 671 } 672 } 673 674 /// Generates code to serializes the result of SPIRV_Op `op` into `os`. The 675 /// generated gets the ID for the type of the result (if any), the SSA-ID of 676 /// the result and updates `resultID` with the SSA-ID. 677 static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc, 678 StringRef tabs, StringRef opVar, 679 StringRef operands, StringRef resultID, 680 raw_ostream &os) { 681 if (op.getNumResults() == 1) { 682 StringRef resultTypeID("resultTypeID"); 683 os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID); 684 os << tabs 685 << formatv( 686 "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n", 687 opVar, resultTypeID); 688 os << tabs << " return failure();\n"; 689 os << tabs << "}\n"; 690 os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID); 691 // Create an SSA result <id> for the op 692 os << tabs << formatv("{0} = getNextID();\n", resultID); 693 os << tabs 694 << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID); 695 os << tabs << formatv("{0}.push_back({1});\n", operands, resultID); 696 } else if (op.getNumResults() != 0) { 697 PrintFatalError(loc, "SPIR-V ops can only have zero or one result"); 698 } 699 } 700 701 /// Generates code to serialize attributes of SPIRV_Op `op` that become 702 /// decorations on the `resultID` of the serialized operation `opVar` in the 703 /// SPIR-V binary. 704 static void emitDecorationSerialization(const Operator &op, StringRef tabs, 705 StringRef opVar, StringRef elidedAttrs, 706 StringRef resultID, raw_ostream &os) { 707 if (op.getNumResults() == 1) { 708 // All non-argument attributes translated into OpDecorate instruction 709 os << tabs << formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar); 710 os << tabs 711 << formatv(" if (llvm::is_contained({0}, attr.getName())) {{", 712 elidedAttrs); 713 os << tabs << " continue;\n"; 714 os << tabs << " }\n"; 715 os << tabs 716 << formatv( 717 " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n", 718 opVar, resultID); 719 os << tabs << " return failure();\n"; 720 os << tabs << " }\n"; 721 os << tabs << "}\n"; 722 } 723 } 724 725 /// Generates code to serialize an SPIRV_Op `op` into `os`. 726 static void emitSerializationFunction(const Record *attrClass, 727 const Record *record, const Operator &op, 728 raw_ostream &os) { 729 // If the record has 'autogenSerialization' set to 0, nothing to do 730 if (!record->getValueAsBit("autogenSerialization")) 731 return; 732 733 StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"), 734 resultID("resultID"); 735 736 os << formatv( 737 "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n", 738 op.getQualCppClassName(), opVar); 739 740 // Special case for ops without attributes in TableGen definitions 741 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) { 742 std::string extInstSet; 743 std::string opcode; 744 if (record->isSubClassOf("SPIRV_ExtInstOp")) { 745 extInstSet = 746 formatv("\"{0}\"", record->getValueAsString("extendedInstSetName")); 747 opcode = std::to_string(record->getValueAsInt("extendedInstOpcode")); 748 } else { 749 extInstSet = "\"\""; 750 opcode = formatv("static_cast<uint32_t>(spirv::Opcode::{0})", 751 record->getValueAsString("spirvOpName")); 752 } 753 754 os << formatv(" return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n", 755 opVar, extInstSet, opcode); 756 return; 757 } 758 759 os << formatv(" SmallVector<uint32_t, 4> {0};\n", operands); 760 os << formatv(" SmallVector<StringRef, 2> {0};\n", elidedAttrs); 761 762 // Serialize result information. 763 if (op.getNumResults() == 1) { 764 os << formatv(" uint32_t {0} = 0;\n", resultID); 765 emitResultSerialization(op, record->getLoc(), " ", opVar, operands, 766 resultID, os); 767 } 768 769 // Process arguments. 770 emitArgumentSerialization(op, record->getLoc(), " ", opVar, operands, 771 elidedAttrs, os); 772 773 if (record->isSubClassOf("SPIRV_ExtInstOp")) { 774 os << formatv( 775 " (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", opVar, 776 record->getValueAsString("extendedInstSetName"), 777 record->getValueAsInt("extendedInstOpcode"), operands); 778 } else { 779 // Emit debug info. 780 os << formatv(" (void)emitDebugLine(functionBody, {0}.getLoc());\n", 781 opVar); 782 os << formatv(" (void)encodeInstructionInto(" 783 "functionBody, spirv::Opcode::{0}, {1});\n", 784 record->getValueAsString("spirvOpName"), operands); 785 } 786 787 // Process decorations. 788 emitDecorationSerialization(op, " ", opVar, elidedAttrs, resultID, os); 789 790 os << " return success();\n"; 791 os << "}\n\n"; 792 } 793 794 /// Generates the prologue for the function that dispatches the serialization of 795 /// the operation `opVar` based on its opcode. 796 static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) { 797 os << formatv( 798 "LogicalResult Serializer::dispatchToAutogenSerialization(Operation " 799 "*{0}) {{\n", 800 opVar); 801 } 802 803 /// Generates the body of the dispatch function. This function generates the 804 /// check that if satisfied, will call the serialization function generated for 805 /// the `op`. 806 static void emitSerializationDispatch(const Operator &op, StringRef tabs, 807 StringRef opVar, raw_ostream &os) { 808 os << tabs 809 << formatv("if (isa<{0}>({1})) {{\n", op.getQualCppClassName(), opVar); 810 os << tabs 811 << formatv(" return processOp(cast<{0}>({1}));\n", 812 op.getQualCppClassName(), opVar); 813 os << tabs << "}\n"; 814 } 815 816 /// Generates the epilogue for the function that dispatches the serialization of 817 /// the operation. 818 static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) { 819 os << formatv( 820 " return {0}->emitError(\"unhandled operation serialization\");\n", 821 opVar); 822 os << "}\n\n"; 823 } 824 825 /// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The 826 /// generated code reads the `words` of the serialized instruction at 827 /// position `wordIndex` and adds the deserialized attribute into `attrList`. 828 static void emitAttributeDeserialization(const Attribute &attr, 829 ArrayRef<SMLoc> loc, StringRef tabs, 830 StringRef attrList, StringRef attrName, 831 StringRef words, StringRef wordIndex, 832 raw_ostream &os) { 833 if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) { 834 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); 835 os << tabs 836 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " 837 "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>(" 838 "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n", 839 attrList, attrName, baseEnum.getCppNamespace(), 840 baseEnum.getEnumClassName(), words, wordIndex); 841 } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") || 842 attr.isSubClassOf("SPIRV_I32EnumAttr")) { 843 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum")); 844 os << tabs 845 << formatv(" {0}.push_back(opBuilder.getNamedAttr(\"{1}\", " 846 "opBuilder.getAttr<{2}::{3}Attr>(" 847 "static_cast<{2}::{3}>({4}[{5}++]))));\n", 848 attrList, attrName, baseEnum.getCppNamespace(), 849 baseEnum.getEnumClassName(), words, wordIndex); 850 } else if (attr.getAttrDefName() == "I32ArrayAttr") { 851 os << tabs << "SmallVector<Attribute, 4> attrListElems;\n"; 852 os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words); 853 os << tabs 854 << formatv( 855 " " 856 "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))" 857 ";\n", 858 words, wordIndex); 859 os << tabs << "}\n"; 860 os << tabs 861 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " 862 "opBuilder.getArrayAttr(attrListElems)));\n", 863 attrList, attrName); 864 } else if (attr.getAttrDefName() == "I32Attr") { 865 os << tabs 866 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " 867 "opBuilder.getI32IntegerAttr({2}[{3}++])));\n", 868 attrList, attrName, words, wordIndex); 869 } else if (attr.isEnumAttr() || attr.isTypeAttr()) { 870 os << tabs 871 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " 872 "TypeAttr::get(getType({2}[{3}++]))));\n", 873 attrList, attrName, words, wordIndex); 874 } else { 875 PrintFatalError( 876 loc, llvm::Twine( 877 "unhandled attribute type in deserialization generation : '") + 878 attrName + llvm::Twine("'")); 879 } 880 } 881 882 /// Generates the code to deserialize the result of an SPIRV_Op `op` into 883 /// `os`. The generated code gets the type of the result specified at 884 /// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1 885 /// and updates the `resultType` and `valueID` with the parsed type and SSA ID, 886 /// respectively. 887 static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc, 888 StringRef tabs, StringRef words, 889 StringRef wordIndex, 890 StringRef resultTypes, StringRef valueID, 891 raw_ostream &os) { 892 // Deserialize result information if it exists 893 if (op.getNumResults() == 1) { 894 os << tabs << "{\n"; 895 os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words); 896 os << tabs 897 << formatv( 898 " return emitError(unknownLoc, \"expected result type <id> " 899 "while deserializing {0}\");\n", 900 op.getQualCppClassName()); 901 os << tabs << " }\n"; 902 os << tabs << formatv(" auto ty = getType({0}[{1}]);\n", words, wordIndex); 903 os << tabs << " if (!ty) {\n"; 904 os << tabs 905 << formatv( 906 " return emitError(unknownLoc, \"unknown type result <id> : " 907 "\") << {0}[{1}];\n", 908 words, wordIndex); 909 os << tabs << " }\n"; 910 os << tabs << formatv(" {0}.push_back(ty);\n", resultTypes); 911 os << tabs << formatv(" {0}++;\n", wordIndex); 912 os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words); 913 os << tabs 914 << formatv( 915 " return emitError(unknownLoc, \"expected result <id> while " 916 "deserializing {0}\");\n", 917 op.getQualCppClassName()); 918 os << tabs << " }\n"; 919 os << tabs << "}\n"; 920 os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex); 921 } else if (op.getNumResults() != 0) { 922 PrintFatalError(loc, "SPIR-V ops can have only zero or one result"); 923 } 924 } 925 926 /// Generates the code to deserialize the operands of an SPIRV_Op `op` into 927 /// `os`. The generated code reads the `words` of the binary instruction, from 928 /// position `wordIndex` to the end, and either gets the Value corresponding to 929 /// the ID encoded, or deserializes the attributes encoded. The parsed 930 /// operand(attribute) is added to the `operands` list or `attributes` list. 931 static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc, 932 StringRef tabs, StringRef words, 933 StringRef wordIndex, StringRef operands, 934 StringRef attributes, raw_ostream &os) { 935 // Process operands/attributes 936 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { 937 auto argument = op.getArg(i); 938 if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) { 939 if (valueArg->isVariableLength()) { 940 if (i != e - 1) { 941 PrintFatalError( 942 loc, "SPIR-V ops can have Variadic<..> or " 943 "Optional<...> arguments only if it's the last argument"); 944 } 945 os << tabs 946 << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words); 947 } else { 948 os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words); 949 } 950 os << " {\n"; 951 os << tabs 952 << formatv(" auto arg = getValue({0}[{1}]);\n", words, wordIndex); 953 os << tabs << " if (!arg) {\n"; 954 os << tabs 955 << formatv( 956 " return emitError(unknownLoc, \"unknown result <id> : \") " 957 "<< {0}[{1}];\n", 958 words, wordIndex); 959 os << tabs << " }\n"; 960 os << tabs << formatv(" {0}.push_back(arg);\n", operands); 961 if (!valueArg->isVariableLength()) { 962 os << tabs << formatv(" {0}++;\n", wordIndex); 963 } 964 os << tabs << "}\n"; 965 } else { 966 os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words); 967 auto *attr = cast<NamedAttribute *>(argument); 968 auto newtabs = tabs.str() + " "; 969 emitAttributeDeserialization( 970 (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), 971 loc, newtabs, attributes, attr->name, words, wordIndex, os); 972 os << " }\n"; 973 } 974 } 975 976 os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words); 977 os << tabs 978 << formatv( 979 " return emitError(unknownLoc, \"found more operands than " 980 "expected when deserializing {0}, only \") << {1} << \" of \" << " 981 "{2}.size() << \" processed\";\n", 982 op.getQualCppClassName(), wordIndex, words); 983 os << tabs << "}\n\n"; 984 } 985 986 /// Generates code to update the `attributes` vector with the attributes 987 /// obtained from parsing the decorations in the SPIR-V binary associated with 988 /// an <id> `valueID` 989 static void emitDecorationDeserialization(const Operator &op, StringRef tabs, 990 StringRef valueID, 991 StringRef attributes, 992 raw_ostream &os) { 993 // Import decorations parsed 994 if (op.getNumResults() == 1) { 995 os << tabs << formatv("if (decorations.count({0})) {{\n", valueID); 996 os << tabs 997 << formatv(" auto attrs = decorations[{0}].getAttrs();\n", valueID); 998 os << tabs 999 << formatv(" {0}.append(attrs.begin(), attrs.end());\n", attributes); 1000 os << tabs << "}\n"; 1001 } 1002 } 1003 1004 /// Generates code to deserialize an SPIRV_Op `op` into `os`. 1005 static void emitDeserializationFunction(const Record *attrClass, 1006 const Record *record, 1007 const Operator &op, raw_ostream &os) { 1008 // If the record has 'autogenSerialization' set to 0, nothing to do 1009 if (!record->getValueAsBit("autogenSerialization")) 1010 return; 1011 1012 StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"), 1013 wordIndex("wordIndex"), opVar("op"), operands("operands"), 1014 attributes("attributes"); 1015 1016 // Method declaration 1017 os << formatv("template <> " 1018 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<" 1019 "uint32_t> {1}) {{\n", 1020 op.getQualCppClassName(), words); 1021 1022 // Special case for ops without attributes in TableGen definitions 1023 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) { 1024 os << formatv(" return processOpWithoutGrammarAttr(" 1025 "{0}, \"{1}\", {2}, {3});\n}\n\n", 1026 words, op.getOperationName(), 1027 op.getNumResults() ? "true" : "false", op.getNumOperands()); 1028 return; 1029 } 1030 1031 os << formatv(" SmallVector<Type, 1> {0};\n", resultTypes); 1032 os << formatv(" size_t {0} = 0; (void){0};\n", wordIndex); 1033 os << formatv(" uint32_t {0} = 0; (void){0};\n", valueID); 1034 1035 // Deserialize result information 1036 emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex, 1037 resultTypes, valueID, os); 1038 1039 os << formatv(" SmallVector<Value, 4> {0};\n", operands); 1040 os << formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes); 1041 // Operand deserialization 1042 emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex, 1043 operands, attributes, os); 1044 1045 // Decorations 1046 emitDecorationDeserialization(op, " ", valueID, attributes, os); 1047 1048 os << formatv(" Location loc = createFileLineColLoc(opBuilder);\n"); 1049 os << formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); " 1050 "(void){1};\n", 1051 op.getQualCppClassName(), opVar, resultTypes, operands, 1052 attributes); 1053 if (op.getNumResults() == 1) { 1054 os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar); 1055 } 1056 1057 // According to SPIR-V spec: 1058 // This location information applies to the instructions physically following 1059 // this instruction, up to the first occurrence of any of the following: the 1060 // next end of block. 1061 os << formatv(" if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar); 1062 os << formatv(" (void)clearDebugLine();\n"); 1063 os << " return success();\n"; 1064 os << "}\n\n"; 1065 } 1066 1067 /// Generates the prologue for the function that dispatches the deserialization 1068 /// based on the `opcode`. 1069 static void initDispatchDeserializationFn(StringRef opcode, StringRef words, 1070 raw_ostream &os) { 1071 os << formatv("LogicalResult spirv::Deserializer::" 1072 "dispatchToAutogenDeserialization(spirv::Opcode {0}," 1073 " ArrayRef<uint32_t> {1}) {{\n", 1074 opcode, words); 1075 os << formatv(" switch ({0}) {{\n", opcode); 1076 } 1077 1078 /// Generates the body of the dispatch function, by generating the case label 1079 /// for an opcode and the call to the method to perform the deserialization. 1080 static void emitDeserializationDispatch(const Operator &op, const Record *def, 1081 StringRef tabs, StringRef words, 1082 raw_ostream &os) { 1083 os << tabs 1084 << formatv("case spirv::Opcode::{0}:\n", 1085 def->getValueAsString("spirvOpName")); 1086 os << tabs 1087 << formatv(" return processOp<{0}>({1});\n", op.getQualCppClassName(), 1088 words); 1089 } 1090 1091 /// Generates the epilogue for the function that dispatches the deserialization 1092 /// of the operation. 1093 static void finalizeDispatchDeserializationFn(StringRef opcode, 1094 raw_ostream &os) { 1095 os << " default:\n"; 1096 os << " ;\n"; 1097 os << " }\n"; 1098 StringRef opcodeVar("opcodeString"); 1099 os << formatv(" auto {0} = spirv::stringifyOpcode({1});\n", opcodeVar, 1100 opcode); 1101 os << formatv(" if (!{0}.empty()) {{\n", opcodeVar); 1102 os << formatv(" return emitError(unknownLoc, \"unhandled deserialization " 1103 "of \") << {0};\n", 1104 opcodeVar); 1105 os << " } else {\n"; 1106 os << formatv(" return emitError(unknownLoc, \"unhandled opcode \") << " 1107 "static_cast<uint32_t>({0});\n", 1108 opcode); 1109 os << " }\n"; 1110 os << "}\n"; 1111 } 1112 1113 static void initExtendedSetDeserializationDispatch(StringRef extensionSetName, 1114 StringRef instructionID, 1115 StringRef words, 1116 raw_ostream &os) { 1117 os << formatv("LogicalResult spirv::Deserializer::" 1118 "dispatchToExtensionSetAutogenDeserialization(" 1119 "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n", 1120 extensionSetName, instructionID, words); 1121 } 1122 1123 static void emitExtendedSetDeserializationDispatch(const RecordKeeper &records, 1124 raw_ostream &os) { 1125 StringRef extensionSetName("extensionSetName"), 1126 instructionID("instructionID"), words("words"); 1127 1128 // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all 1129 // extensionSets. 1130 1131 // For each of the extensions a separate raw_string_ostream is used to 1132 // generate code into. These are then concatenated at the end. Since 1133 // raw_string_ostream needs a string&, use a vector to store all the string 1134 // that are captured by reference within raw_string_ostream. 1135 StringMap<raw_string_ostream> extensionSets; 1136 std::list<std::string> extensionSetNames; 1137 1138 initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words, 1139 os); 1140 auto defs = records.getAllDerivedDefinitions("SPIRV_ExtInstOp"); 1141 for (const auto *def : defs) { 1142 if (!def->getValueAsBit("autogenSerialization")) { 1143 continue; 1144 } 1145 Operator op(def); 1146 auto setName = def->getValueAsString("extendedInstSetName"); 1147 if (!extensionSets.count(setName)) { 1148 extensionSetNames.emplace_back(""); 1149 extensionSets.try_emplace(setName, extensionSetNames.back()); 1150 auto &setos = extensionSets.find(setName)->second; 1151 setos << formatv(" if ({0} == \"{1}\") {{\n", extensionSetName, setName); 1152 setos << formatv(" switch ({0}) {{\n", instructionID); 1153 } 1154 auto &setos = extensionSets.find(setName)->second; 1155 setos << formatv(" case {0}:\n", 1156 def->getValueAsInt("extendedInstOpcode")); 1157 setos << formatv(" return processOp<{0}>({1});\n", 1158 op.getQualCppClassName(), words); 1159 } 1160 1161 // Append the dispatch code for all the extended sets. 1162 for (auto &extensionSet : extensionSets) { 1163 os << extensionSet.second.str(); 1164 os << " default:\n"; 1165 os << formatv( 1166 " return emitError(unknownLoc, \"unhandled deserializations of " 1167 "\") << {0} << \" from extension set \" << {1};\n", 1168 instructionID, extensionSetName); 1169 os << " }\n"; 1170 os << " }\n"; 1171 } 1172 1173 os << formatv(" return emitError(unknownLoc, \"unhandled deserialization of " 1174 "extended instruction set {0}\");\n", 1175 extensionSetName); 1176 os << "}\n"; 1177 } 1178 1179 /// Emits all the autogenerated serialization/deserializations functions for the 1180 /// SPIRV_Ops. 1181 static bool emitSerializationFns(const RecordKeeper &records, raw_ostream &os) { 1182 llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os, 1183 records); 1184 1185 std::string dSerFnString, dDesFnString, serFnString, deserFnString, 1186 utilsString; 1187 raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString), 1188 serFn(serFnString), deserFn(deserFnString); 1189 const Record *attrClass = records.getClass("Attr"); 1190 1191 // Emit the serialization and deserialization functions simultaneously. 1192 StringRef opVar("op"); 1193 StringRef opcode("opcode"), words("words"); 1194 1195 // Handle the SPIR-V ops. 1196 initDispatchSerializationFn(opVar, dSerFn); 1197 initDispatchDeserializationFn(opcode, words, dDesFn); 1198 auto defs = records.getAllDerivedDefinitions("SPIRV_Op"); 1199 for (const auto *def : defs) { 1200 Operator op(def); 1201 emitSerializationFunction(attrClass, def, op, serFn); 1202 emitDeserializationFunction(attrClass, def, op, deserFn); 1203 if (def->getValueAsBit("hasOpcode") || 1204 def->isSubClassOf("SPIRV_ExtInstOp")) { 1205 emitSerializationDispatch(op, " ", opVar, dSerFn); 1206 } 1207 if (def->getValueAsBit("hasOpcode")) { 1208 emitDeserializationDispatch(op, def, " ", words, dDesFn); 1209 } 1210 } 1211 finalizeDispatchSerializationFn(opVar, dSerFn); 1212 finalizeDispatchDeserializationFn(opcode, dDesFn); 1213 1214 emitExtendedSetDeserializationDispatch(records, dDesFn); 1215 1216 os << "#ifdef GET_SERIALIZATION_FNS\n\n"; 1217 os << serFn.str(); 1218 os << dSerFn.str(); 1219 os << "#endif // GET_SERIALIZATION_FNS\n\n"; 1220 1221 os << "#ifdef GET_DESERIALIZATION_FNS\n\n"; 1222 os << deserFn.str(); 1223 os << dDesFn.str(); 1224 os << "#endif // GET_DESERIALIZATION_FNS\n\n"; 1225 1226 return false; 1227 } 1228 1229 //===----------------------------------------------------------------------===// 1230 // Serialization Hook Registration 1231 //===----------------------------------------------------------------------===// 1232 1233 static mlir::GenRegistration genSerialization( 1234 "gen-spirv-serialization", 1235 "Generate SPIR-V (de)serialization utilities and functions", 1236 [](const RecordKeeper &records, raw_ostream &os) { 1237 return emitSerializationFns(records, os); 1238 }); 1239 1240 //===----------------------------------------------------------------------===// 1241 // Op Utils AutoGen 1242 //===----------------------------------------------------------------------===// 1243 1244 static void emitEnumGetAttrNameFnDecl(raw_ostream &os) { 1245 os << formatv("template <typename EnumClass> inline constexpr StringRef " 1246 "attributeName();\n"); 1247 } 1248 1249 static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, 1250 raw_ostream &os) { 1251 auto enumName = enumAttr.getEnumClassName(); 1252 os << formatv("template <> inline StringRef attributeName<{0}>() {{\n", 1253 enumName); 1254 os << " " 1255 << formatv("static constexpr const char attrName[] = \"{0}\";\n", 1256 llvm::convertToSnakeFromCamelCase(enumName)); 1257 os << " return attrName;\n"; 1258 os << "}\n"; 1259 } 1260 1261 static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) { 1262 llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os, records); 1263 1264 auto defs = records.getAllDerivedDefinitions("EnumAttrInfo"); 1265 os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n"; 1266 os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n"; 1267 emitEnumGetAttrNameFnDecl(os); 1268 for (const auto *def : defs) { 1269 EnumAttr enumAttr(*def); 1270 emitEnumGetAttrNameFnDefn(enumAttr, os); 1271 } 1272 os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n"; 1273 return false; 1274 } 1275 1276 //===----------------------------------------------------------------------===// 1277 // Op Utils Hook Registration 1278 //===----------------------------------------------------------------------===// 1279 1280 static mlir::GenRegistration 1281 genOpUtils("gen-spirv-attr-utils", 1282 "Generate SPIR-V attribute utility definitions", 1283 [](const RecordKeeper &records, raw_ostream &os) { 1284 return emitAttrUtils(records, os); 1285 }); 1286 1287 //===----------------------------------------------------------------------===// 1288 // SPIR-V Availability Impl AutoGen 1289 //===----------------------------------------------------------------------===// 1290 1291 static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { 1292 mlir::tblgen::FmtContext fctx; 1293 fctx.addSubst("overall", "tblgen_overall"); 1294 1295 std::vector<Availability> opAvailabilities = 1296 getAvailabilities(srcOp.getDef()); 1297 1298 // First collect all availability classes this op should implement. 1299 // All availability instances keep information for the generated interface and 1300 // the instance's specific requirement. Here we remember a random instance so 1301 // we can get the information regarding the generated interface. 1302 llvm::StringMap<Availability> availClasses; 1303 for (const Availability &avail : opAvailabilities) 1304 availClasses.try_emplace(avail.getClass(), avail); 1305 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { 1306 if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") && 1307 !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr")) 1308 continue; 1309 EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum")); 1310 1311 for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) 1312 for (const Availability &caseAvail : 1313 getAvailabilities(enumerant.getDef())) 1314 availClasses.try_emplace(caseAvail.getClass(), caseAvail); 1315 } 1316 1317 // Then generate implementation for each availability class. 1318 for (const auto &availClass : availClasses) { 1319 StringRef availClassName = availClass.getKey(); 1320 Availability avail = availClass.getValue(); 1321 1322 // Generate the implementation method signature. 1323 os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(), 1324 srcOp.getCppClassName(), avail.getQueryFnName()); 1325 1326 // Create the variable for the final requirement and initialize it. 1327 os << formatv(" {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(), 1328 avail.getMergeInitializer()); 1329 1330 // Update with the op's specific availability spec. 1331 for (const Availability &avail : opAvailabilities) 1332 if (avail.getClass() == availClassName && 1333 (!avail.getMergeInstancePreparation().empty() || 1334 !avail.getMergeActionCode().empty())) { 1335 os << " {\n " 1336 // Prepare this instance. 1337 << avail.getMergeInstancePreparation() 1338 << "\n " 1339 // Merge this instance. 1340 << std::string( 1341 tgfmt(avail.getMergeActionCode(), 1342 &fctx.addSubst("instance", avail.getMergeInstance()))) 1343 << ";\n }\n"; 1344 } 1345 1346 // Update with enum attributes' specific availability spec. 1347 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { 1348 if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") && 1349 !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr")) 1350 continue; 1351 EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum")); 1352 1353 // (enumerant, availability specification) pairs for this availability 1354 // class. 1355 SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs; 1356 1357 // Collect all cases' availability specs. 1358 for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) 1359 for (const Availability &caseAvail : 1360 getAvailabilities(enumerant.getDef())) 1361 if (availClassName == caseAvail.getClass()) 1362 caseSpecs.push_back({enumerant, caseAvail}); 1363 1364 // If this attribute kind does not have any availability spec from any of 1365 // its cases, no more work to do. 1366 if (caseSpecs.empty()) 1367 continue; 1368 1369 if (enumAttr.isBitEnum()) { 1370 // For BitEnumAttr, we need to iterate over each bit to query its 1371 // availability spec. 1372 os << formatv(" for (unsigned i = 0; " 1373 "i < std::numeric_limits<{0}>::digits; ++i) {{\n", 1374 enumAttr.getUnderlyingType()); 1375 os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & " 1376 "static_cast<{0}::{1}>(1 << i);\n", 1377 enumAttr.getCppNamespace(), enumAttr.getEnumClassName(), 1378 srcOp.getGetterName(namedAttr.name)); 1379 os << formatv( 1380 " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n", 1381 enumAttr.getUnderlyingType()); 1382 } else { 1383 // For IntEnumAttr, we just need to query the value as a whole. 1384 os << " {\n"; 1385 os << formatv(" auto tblgen_attrVal = this->{0}();\n", 1386 srcOp.getGetterName(namedAttr.name)); 1387 } 1388 os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n", 1389 enumAttr.getCppNamespace(), avail.getQueryFnName()); 1390 os << " if (tblgen_instance) " 1391 // TODO` here once ODS supports 1392 // dialect-specific contents so that we can use not implementing the 1393 // availability interface as indication of no requirements. 1394 << std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(), 1395 &fctx.addSubst("instance", "*tblgen_instance"))) 1396 << ";\n"; 1397 os << " }\n"; 1398 } 1399 1400 os << " return tblgen_overall;\n"; 1401 os << "}\n"; 1402 } 1403 } 1404 1405 static bool emitAvailabilityImpl(const RecordKeeper &records, raw_ostream &os) { 1406 llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os, 1407 records); 1408 1409 auto defs = records.getAllDerivedDefinitions("SPIRV_Op"); 1410 for (const auto *def : defs) { 1411 Operator op(def); 1412 if (def->getValueAsBit("autogenAvailability")) 1413 emitAvailabilityImpl(op, os); 1414 } 1415 return false; 1416 } 1417 1418 //===----------------------------------------------------------------------===// 1419 // Op Availability Implementation Hook Registration 1420 //===----------------------------------------------------------------------===// 1421 1422 static mlir::GenRegistration 1423 genOpAvailabilityImpl("gen-spirv-avail-impls", 1424 "Generate SPIR-V operation utility definitions", 1425 [](const RecordKeeper &records, raw_ostream &os) { 1426 return emitAvailabilityImpl(records, os); 1427 }); 1428 1429 //===----------------------------------------------------------------------===// 1430 // SPIR-V Capability Implication AutoGen 1431 //===----------------------------------------------------------------------===// 1432 1433 static bool emitCapabilityImplication(const RecordKeeper &records, 1434 raw_ostream &os) { 1435 llvm::emitSourceFileHeader("SPIR-V Capability Implication", os, records); 1436 1437 EnumAttr enumAttr( 1438 records.getDef("SPIRV_CapabilityAttr")->getValueAsDef("enum")); 1439 1440 os << "ArrayRef<spirv::Capability> " 1441 "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n" 1442 << " switch (cap) {\n" 1443 << " default: return {};\n"; 1444 for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) { 1445 const Record &def = enumerant.getDef(); 1446 if (!def.getValue("implies")) 1447 continue; 1448 1449 std::vector<const Record *> impliedCapsDefs = 1450 def.getValueAsListOfDefs("implies"); 1451 os << " case spirv::Capability::" << enumerant.getSymbol() 1452 << ": {static const spirv::Capability implies[" << impliedCapsDefs.size() 1453 << "] = {"; 1454 llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) { 1455 os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol(); 1456 }); 1457 os << "}; return ArrayRef<spirv::Capability>(implies, " 1458 << impliedCapsDefs.size() << "); }\n"; 1459 } 1460 os << " }\n"; 1461 os << "}\n"; 1462 1463 return false; 1464 } 1465 1466 //===----------------------------------------------------------------------===// 1467 // SPIR-V Capability Implication Hook Registration 1468 //===----------------------------------------------------------------------===// 1469 1470 static mlir::GenRegistration 1471 genCapabilityImplication("gen-spirv-capability-implication", 1472 "Generate utility function to return implied " 1473 "capabilities for a given capability", 1474 [](const RecordKeeper &records, raw_ostream &os) { 1475 return emitCapabilityImplication(records, os); 1476 }); 1477