1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Bytecode/BytecodeReader.h" 10 #include "mlir/AsmParser/AsmParser.h" 11 #include "mlir/Bytecode/BytecodeImplementation.h" 12 #include "mlir/Bytecode/BytecodeOpInterface.h" 13 #include "mlir/Bytecode/Encoding.h" 14 #include "mlir/IR/BuiltinDialect.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/Diagnostics.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/Verifier.h" 19 #include "mlir/IR/Visitors.h" 20 #include "mlir/Support/LLVM.h" 21 #include "mlir/Support/LogicalResult.h" 22 #include "llvm/ADT/ArrayRef.h" 23 #include "llvm/ADT/MapVector.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/ADT/StringRef.h" 28 #include "llvm/Support/Endian.h" 29 #include "llvm/Support/MemoryBufferRef.h" 30 #include "llvm/Support/SaveAndRestore.h" 31 #include "llvm/Support/SourceMgr.h" 32 #include <cstddef> 33 #include <list> 34 #include <memory> 35 #include <numeric> 36 #include <optional> 37 38 #define DEBUG_TYPE "mlir-bytecode-reader" 39 40 using namespace mlir; 41 42 /// Stringify the given section ID. 43 static std::string toString(bytecode::Section::ID sectionID) { 44 switch (sectionID) { 45 case bytecode::Section::kString: 46 return "String (0)"; 47 case bytecode::Section::kDialect: 48 return "Dialect (1)"; 49 case bytecode::Section::kAttrType: 50 return "AttrType (2)"; 51 case bytecode::Section::kAttrTypeOffset: 52 return "AttrTypeOffset (3)"; 53 case bytecode::Section::kIR: 54 return "IR (4)"; 55 case bytecode::Section::kResource: 56 return "Resource (5)"; 57 case bytecode::Section::kResourceOffset: 58 return "ResourceOffset (6)"; 59 case bytecode::Section::kDialectVersions: 60 return "DialectVersions (7)"; 61 case bytecode::Section::kProperties: 62 return "Properties (8)"; 63 default: 64 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); 65 } 66 } 67 68 /// Returns true if the given top-level section ID is optional. 69 static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { 70 switch (sectionID) { 71 case bytecode::Section::kString: 72 case bytecode::Section::kDialect: 73 case bytecode::Section::kAttrType: 74 case bytecode::Section::kAttrTypeOffset: 75 case bytecode::Section::kIR: 76 return false; 77 case bytecode::Section::kResource: 78 case bytecode::Section::kResourceOffset: 79 case bytecode::Section::kDialectVersions: 80 return true; 81 case bytecode::Section::kProperties: 82 return version < bytecode::kNativePropertiesEncoding; 83 default: 84 llvm_unreachable("unknown section ID"); 85 } 86 } 87 88 //===----------------------------------------------------------------------===// 89 // EncodingReader 90 //===----------------------------------------------------------------------===// 91 92 namespace { 93 class EncodingReader { 94 public: 95 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) 96 : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} 97 explicit EncodingReader(StringRef contents, Location fileLoc) 98 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), 99 contents.size()}, 100 fileLoc) {} 101 102 /// Returns true if the entire section has been read. 103 bool empty() const { return dataIt == dataEnd; } 104 105 /// Returns the remaining size of the bytecode. 106 size_t size() const { return dataEnd - dataIt; } 107 108 /// Align the current reader position to the specified alignment. 109 LogicalResult alignTo(unsigned alignment) { 110 if (!llvm::isPowerOf2_32(alignment)) 111 return emitError("expected alignment to be a power-of-two"); 112 113 // Shift the reader position to the next alignment boundary. 114 while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) { 115 uint8_t padding; 116 if (failed(parseByte(padding))) 117 return failure(); 118 if (padding != bytecode::kAlignmentByte) { 119 return emitError("expected alignment byte (0xCB), but got: '0x" + 120 llvm::utohexstr(padding) + "'"); 121 } 122 } 123 124 // Ensure the data iterator is now aligned. This case is unlikely because we 125 // *just* went through the effort to align the data iterator. 126 if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) { 127 return emitError("expected data iterator aligned to ", alignment, 128 ", but got pointer: '0x" + 129 llvm::utohexstr((uintptr_t)dataIt) + "'"); 130 } 131 132 return success(); 133 } 134 135 /// Emit an error using the given arguments. 136 template <typename... Args> 137 InFlightDiagnostic emitError(Args &&...args) const { 138 return ::emitError(fileLoc).append(std::forward<Args>(args)...); 139 } 140 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } 141 142 /// Parse a single byte from the stream. 143 template <typename T> 144 LogicalResult parseByte(T &value) { 145 if (empty()) 146 return emitError("attempting to parse a byte at the end of the bytecode"); 147 value = static_cast<T>(*dataIt++); 148 return success(); 149 } 150 /// Parse a range of bytes of 'length' into the given result. 151 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { 152 if (length > size()) { 153 return emitError("attempting to parse ", length, " bytes when only ", 154 size(), " remain"); 155 } 156 result = {dataIt, length}; 157 dataIt += length; 158 return success(); 159 } 160 /// Parse a range of bytes of 'length' into the given result, which can be 161 /// assumed to be large enough to hold `length`. 162 LogicalResult parseBytes(size_t length, uint8_t *result) { 163 if (length > size()) { 164 return emitError("attempting to parse ", length, " bytes when only ", 165 size(), " remain"); 166 } 167 memcpy(result, dataIt, length); 168 dataIt += length; 169 return success(); 170 } 171 172 /// Parse an aligned blob of data, where the alignment was encoded alongside 173 /// the data. 174 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, 175 uint64_t &alignment) { 176 uint64_t dataSize; 177 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || 178 failed(alignTo(alignment))) 179 return failure(); 180 return parseBytes(dataSize, data); 181 } 182 183 /// Parse a variable length encoded integer from the byte stream. The first 184 /// encoded byte contains a prefix in the low bits indicating the encoded 185 /// length of the value. This length prefix is a bit sequence of '0's followed 186 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes 187 /// (not including the prefix byte). All remaining bits in the first byte, 188 /// along with all of the bits in additional bytes, provide the value of the 189 /// integer encoded in little-endian order. 190 LogicalResult parseVarInt(uint64_t &result) { 191 // Parse the first byte of the encoding, which contains the length prefix. 192 if (failed(parseByte(result))) 193 return failure(); 194 195 // Handle the overwhelmingly common case where the value is stored in a 196 // single byte. In this case, the first bit is the `1` marker bit. 197 if (LLVM_LIKELY(result & 1)) { 198 result >>= 1; 199 return success(); 200 } 201 202 // Handle the overwhelming uncommon case where the value required all 8 203 // bytes (i.e. a really really big number). In this case, the marker byte is 204 // all zeros: `00000000`. 205 if (LLVM_UNLIKELY(result == 0)) { 206 llvm::support::ulittle64_t resultLE; 207 if (failed(parseBytes(sizeof(resultLE), 208 reinterpret_cast<uint8_t *>(&resultLE)))) 209 return failure(); 210 result = resultLE; 211 return success(); 212 } 213 return parseMultiByteVarInt(result); 214 } 215 216 /// Parse a signed variable length encoded integer from the byte stream. A 217 /// signed varint is encoded as a normal varint with zigzag encoding applied, 218 /// i.e. the low bit of the value is used to indicate the sign. 219 LogicalResult parseSignedVarInt(uint64_t &result) { 220 if (failed(parseVarInt(result))) 221 return failure(); 222 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) 223 result = (result >> 1) ^ (~(result & 1) + 1); 224 return success(); 225 } 226 227 /// Parse a variable length encoded integer whose low bit is used to encode an 228 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. 229 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { 230 if (failed(parseVarInt(result))) 231 return failure(); 232 flag = result & 1; 233 result >>= 1; 234 return success(); 235 } 236 237 /// Skip the first `length` bytes within the reader. 238 LogicalResult skipBytes(size_t length) { 239 if (length > size()) { 240 return emitError("attempting to skip ", length, " bytes when only ", 241 size(), " remain"); 242 } 243 dataIt += length; 244 return success(); 245 } 246 247 /// Parse a null-terminated string into `result` (without including the NUL 248 /// terminator). 249 LogicalResult parseNullTerminatedString(StringRef &result) { 250 const char *startIt = (const char *)dataIt; 251 const char *nulIt = (const char *)memchr(startIt, 0, size()); 252 if (!nulIt) 253 return emitError( 254 "malformed null-terminated string, no null character found"); 255 256 result = StringRef(startIt, nulIt - startIt); 257 dataIt = (const uint8_t *)nulIt + 1; 258 return success(); 259 } 260 261 /// Parse a section header, placing the kind of section in `sectionID` and the 262 /// contents of the section in `sectionData`. 263 LogicalResult parseSection(bytecode::Section::ID §ionID, 264 ArrayRef<uint8_t> §ionData) { 265 uint8_t sectionIDAndHasAlignment; 266 uint64_t length; 267 if (failed(parseByte(sectionIDAndHasAlignment)) || 268 failed(parseVarInt(length))) 269 return failure(); 270 271 // Extract the section ID and whether the section is aligned. The high bit 272 // of the ID is the alignment flag. 273 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & 274 0b01111111); 275 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; 276 277 // Check that the section is actually valid before trying to process its 278 // data. 279 if (sectionID >= bytecode::Section::kNumSections) 280 return emitError("invalid section ID: ", unsigned(sectionID)); 281 282 // Process the section alignment if present. 283 if (hasAlignment) { 284 uint64_t alignment; 285 if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) 286 return failure(); 287 } 288 289 // Parse the actual section data. 290 return parseBytes(static_cast<size_t>(length), sectionData); 291 } 292 293 Location getLoc() const { return fileLoc; } 294 295 private: 296 /// Parse a variable length encoded integer from the byte stream. This method 297 /// is a fallback when the number of bytes used to encode the value is greater 298 /// than 1, but less than the max (9). The provided `result` value can be 299 /// assumed to already contain the first byte of the value. 300 /// NOTE: This method is marked noinline to avoid pessimizing the common case 301 /// of single byte encoding. 302 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) { 303 // Count the number of trailing zeros in the marker byte, this indicates the 304 // number of trailing bytes that are part of the value. We use `uint32_t` 305 // here because we only care about the first byte, and so that be actually 306 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop 307 // implementation). 308 uint32_t numBytes = llvm::countr_zero<uint32_t>(result); 309 assert(numBytes > 0 && numBytes <= 7 && 310 "unexpected number of trailing zeros in varint encoding"); 311 312 // Parse in the remaining bytes of the value. 313 llvm::support::ulittle64_t resultLE(result); 314 if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1))) 315 return failure(); 316 317 // Shift out the low-order bits that were used to mark how the value was 318 // encoded. 319 result = resultLE >> (numBytes + 1); 320 return success(); 321 } 322 323 /// The current data iterator, and an iterator to the end of the buffer. 324 const uint8_t *dataIt, *dataEnd; 325 326 /// A location for the bytecode used to report errors. 327 Location fileLoc; 328 }; 329 } // namespace 330 331 /// Resolve an index into the given entry list. `entry` may either be a 332 /// reference, in which case it is assigned to the corresponding value in 333 /// `entries`, or a pointer, in which case it is assigned to the address of the 334 /// element in `entries`. 335 template <typename RangeT, typename T> 336 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, 337 uint64_t index, T &entry, 338 StringRef entryStr) { 339 if (index >= entries.size()) 340 return reader.emitError("invalid ", entryStr, " index: ", index); 341 342 // If the provided entry is a pointer, resolve to the address of the entry. 343 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) 344 entry = entries[index]; 345 else 346 entry = &entries[index]; 347 return success(); 348 } 349 350 /// Parse and resolve an index into the given entry list. 351 template <typename RangeT, typename T> 352 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, 353 T &entry, StringRef entryStr) { 354 uint64_t entryIdx; 355 if (failed(reader.parseVarInt(entryIdx))) 356 return failure(); 357 return resolveEntry(reader, entries, entryIdx, entry, entryStr); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // StringSectionReader 362 //===----------------------------------------------------------------------===// 363 364 namespace { 365 /// This class is used to read references to the string section from the 366 /// bytecode. 367 class StringSectionReader { 368 public: 369 /// Initialize the string section reader with the given section data. 370 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); 371 372 /// Parse a shared string from the string section. The shared string is 373 /// encoded using an index to a corresponding string in the string section. 374 LogicalResult parseString(EncodingReader &reader, StringRef &result) { 375 return parseEntry(reader, strings, result, "string"); 376 } 377 378 /// Parse a shared string from the string section. The shared string is 379 /// encoded using an index to a corresponding string in the string section. 380 /// This variant parses a flag compressed with the index. 381 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, 382 bool &flag) { 383 uint64_t entryIdx; 384 if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) 385 return failure(); 386 return parseStringAtIndex(reader, entryIdx, result); 387 } 388 389 /// Parse a shared string from the string section. The shared string is 390 /// encoded using an index to a corresponding string in the string section. 391 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, 392 StringRef &result) { 393 return resolveEntry(reader, strings, index, result, "string"); 394 } 395 396 private: 397 /// The table of strings referenced within the bytecode file. 398 SmallVector<StringRef> strings; 399 }; 400 } // namespace 401 402 LogicalResult StringSectionReader::initialize(Location fileLoc, 403 ArrayRef<uint8_t> sectionData) { 404 EncodingReader stringReader(sectionData, fileLoc); 405 406 // Parse the number of strings in the section. 407 uint64_t numStrings; 408 if (failed(stringReader.parseVarInt(numStrings))) 409 return failure(); 410 strings.resize(numStrings); 411 412 // Parse each of the strings. The sizes of the strings are encoded in reverse 413 // order, so that's the order we populate the table. 414 size_t stringDataEndOffset = sectionData.size(); 415 for (StringRef &string : llvm::reverse(strings)) { 416 uint64_t stringSize; 417 if (failed(stringReader.parseVarInt(stringSize))) 418 return failure(); 419 if (stringDataEndOffset < stringSize) { 420 return stringReader.emitError( 421 "string size exceeds the available data size"); 422 } 423 424 // Extract the string from the data, dropping the null character. 425 size_t stringOffset = stringDataEndOffset - stringSize; 426 string = StringRef( 427 reinterpret_cast<const char *>(sectionData.data() + stringOffset), 428 stringSize - 1); 429 stringDataEndOffset = stringOffset; 430 } 431 432 // Check that the only remaining data was for the strings, i.e. the reader 433 // should be at the same offset as the first string. 434 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { 435 return stringReader.emitError("unexpected trailing data between the " 436 "offsets for strings and their data"); 437 } 438 return success(); 439 } 440 441 //===----------------------------------------------------------------------===// 442 // BytecodeDialect 443 //===----------------------------------------------------------------------===// 444 445 namespace { 446 class DialectReader; 447 448 /// This struct represents a dialect entry within the bytecode. 449 struct BytecodeDialect { 450 /// Load the dialect into the provided context if it hasn't been loaded yet. 451 /// Returns failure if the dialect couldn't be loaded *and* the provided 452 /// context does not allow unregistered dialects. The provided reader is used 453 /// for error emission if necessary. 454 LogicalResult load(const DialectReader &reader, MLIRContext *ctx); 455 456 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can 457 /// only be called after `load`. 458 Dialect *getLoadedDialect() const { 459 assert(dialect && 460 "expected `load` to be invoked before `getLoadedDialect`"); 461 return *dialect; 462 } 463 464 /// The loaded dialect entry. This field is std::nullopt if we haven't 465 /// attempted to load, nullptr if we failed to load, otherwise the loaded 466 /// dialect. 467 std::optional<Dialect *> dialect; 468 469 /// The bytecode interface of the dialect, or nullptr if the dialect does not 470 /// implement the bytecode interface. This field should only be checked if the 471 /// `dialect` field is not std::nullopt. 472 const BytecodeDialectInterface *interface = nullptr; 473 474 /// The name of the dialect. 475 StringRef name; 476 477 /// A buffer containing the encoding of the dialect version parsed. 478 ArrayRef<uint8_t> versionBuffer; 479 480 /// Lazy loaded dialect version from the handle above. 481 std::unique_ptr<DialectVersion> loadedVersion; 482 }; 483 484 /// This struct represents an operation name entry within the bytecode. 485 struct BytecodeOperationName { 486 BytecodeOperationName(BytecodeDialect *dialect, StringRef name, 487 std::optional<bool> wasRegistered) 488 : dialect(dialect), name(name), wasRegistered(wasRegistered) {} 489 490 /// The loaded operation name, or std::nullopt if it hasn't been processed 491 /// yet. 492 std::optional<OperationName> opName; 493 494 /// The dialect that owns this operation name. 495 BytecodeDialect *dialect; 496 497 /// The name of the operation, without the dialect prefix. 498 StringRef name; 499 500 /// Whether this operation was registered when the bytecode was produced. 501 /// This flag is populated when bytecode version >=kNativePropertiesEncoding. 502 std::optional<bool> wasRegistered; 503 }; 504 } // namespace 505 506 /// Parse a single dialect group encoded in the byte stream. 507 static LogicalResult parseDialectGrouping( 508 EncodingReader &reader, 509 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 510 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { 511 // Parse the dialect and the number of entries in the group. 512 std::unique_ptr<BytecodeDialect> *dialect; 513 if (failed(parseEntry(reader, dialects, dialect, "dialect"))) 514 return failure(); 515 uint64_t numEntries; 516 if (failed(reader.parseVarInt(numEntries))) 517 return failure(); 518 519 for (uint64_t i = 0; i < numEntries; ++i) 520 if (failed(entryCallback(dialect->get()))) 521 return failure(); 522 return success(); 523 } 524 525 //===----------------------------------------------------------------------===// 526 // ResourceSectionReader 527 //===----------------------------------------------------------------------===// 528 529 namespace { 530 /// This class is used to read the resource section from the bytecode. 531 class ResourceSectionReader { 532 public: 533 /// Initialize the resource section reader with the given section data. 534 LogicalResult 535 initialize(Location fileLoc, const ParserConfig &config, 536 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 537 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 538 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 539 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); 540 541 /// Parse a dialect resource handle from the resource section. 542 LogicalResult parseResourceHandle(EncodingReader &reader, 543 AsmDialectResourceHandle &result) { 544 return parseEntry(reader, dialectResources, result, "resource handle"); 545 } 546 547 private: 548 /// The table of dialect resources within the bytecode file. 549 SmallVector<AsmDialectResourceHandle> dialectResources; 550 llvm::StringMap<std::string> dialectResourceHandleRenamingMap; 551 }; 552 553 class ParsedResourceEntry : public AsmParsedResourceEntry { 554 public: 555 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, 556 EncodingReader &reader, StringSectionReader &stringReader, 557 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 558 : key(key), kind(kind), reader(reader), stringReader(stringReader), 559 bufferOwnerRef(bufferOwnerRef) {} 560 ~ParsedResourceEntry() override = default; 561 562 StringRef getKey() const final { return key; } 563 564 InFlightDiagnostic emitError() const final { return reader.emitError(); } 565 566 AsmResourceEntryKind getKind() const final { return kind; } 567 568 FailureOr<bool> parseAsBool() const final { 569 if (kind != AsmResourceEntryKind::Bool) 570 return emitError() << "expected a bool resource entry, but found a " 571 << toString(kind) << " entry instead"; 572 573 bool value; 574 if (failed(reader.parseByte(value))) 575 return failure(); 576 return value; 577 } 578 FailureOr<std::string> parseAsString() const final { 579 if (kind != AsmResourceEntryKind::String) 580 return emitError() << "expected a string resource entry, but found a " 581 << toString(kind) << " entry instead"; 582 583 StringRef string; 584 if (failed(stringReader.parseString(reader, string))) 585 return failure(); 586 return string.str(); 587 } 588 589 FailureOr<AsmResourceBlob> 590 parseAsBlob(BlobAllocatorFn allocator) const final { 591 if (kind != AsmResourceEntryKind::Blob) 592 return emitError() << "expected a blob resource entry, but found a " 593 << toString(kind) << " entry instead"; 594 595 ArrayRef<uint8_t> data; 596 uint64_t alignment; 597 if (failed(reader.parseBlobAndAlignment(data, alignment))) 598 return failure(); 599 600 // If we have an extendable reference to the buffer owner, we don't need to 601 // allocate a new buffer for the data, and can use the data directly. 602 if (bufferOwnerRef) { 603 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), 604 data.size()); 605 606 // Allocate an unmanager buffer which captures a reference to the owner. 607 // For now we just mark this as immutable, but in the future we should 608 // explore marking this as mutable when desired. 609 return UnmanagedAsmResourceBlob::allocateWithAlign( 610 charData, alignment, 611 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); 612 } 613 614 // Allocate memory for the blob using the provided allocator and copy the 615 // data into it. 616 AsmResourceBlob blob = allocator(data.size(), alignment); 617 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && 618 blob.isMutable() && 619 "blob allocator did not return a properly aligned address"); 620 memcpy(blob.getMutableData().data(), data.data(), data.size()); 621 return blob; 622 } 623 624 private: 625 StringRef key; 626 AsmResourceEntryKind kind; 627 EncodingReader &reader; 628 StringSectionReader &stringReader; 629 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 630 }; 631 } // namespace 632 633 template <typename T> 634 static LogicalResult 635 parseResourceGroup(Location fileLoc, bool allowEmpty, 636 EncodingReader &offsetReader, EncodingReader &resourceReader, 637 StringSectionReader &stringReader, T *handler, 638 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, 639 function_ref<StringRef(StringRef)> remapKey = {}, 640 function_ref<LogicalResult(StringRef)> processKeyFn = {}) { 641 uint64_t numResources; 642 if (failed(offsetReader.parseVarInt(numResources))) 643 return failure(); 644 645 for (uint64_t i = 0; i < numResources; ++i) { 646 StringRef key; 647 AsmResourceEntryKind kind; 648 uint64_t resourceOffset; 649 ArrayRef<uint8_t> data; 650 if (failed(stringReader.parseString(offsetReader, key)) || 651 failed(offsetReader.parseVarInt(resourceOffset)) || 652 failed(offsetReader.parseByte(kind)) || 653 failed(resourceReader.parseBytes(resourceOffset, data))) 654 return failure(); 655 656 // Process the resource key. 657 if ((processKeyFn && failed(processKeyFn(key)))) 658 return failure(); 659 660 // If the resource data is empty and we allow it, don't error out when 661 // parsing below, just skip it. 662 if (allowEmpty && data.empty()) 663 continue; 664 665 // Ignore the entry if we don't have a valid handler. 666 if (!handler) 667 continue; 668 669 // Otherwise, parse the resource value. 670 EncodingReader entryReader(data, fileLoc); 671 key = remapKey(key); 672 ParsedResourceEntry entry(key, kind, entryReader, stringReader, 673 bufferOwnerRef); 674 if (failed(handler->parseResource(entry))) 675 return failure(); 676 if (!entryReader.empty()) { 677 return entryReader.emitError( 678 "unexpected trailing bytes in resource entry '", key, "'"); 679 } 680 } 681 return success(); 682 } 683 684 LogicalResult ResourceSectionReader::initialize( 685 Location fileLoc, const ParserConfig &config, 686 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 687 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 688 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 689 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 690 EncodingReader resourceReader(sectionData, fileLoc); 691 EncodingReader offsetReader(offsetSectionData, fileLoc); 692 693 // Read the number of external resource providers. 694 uint64_t numExternalResourceGroups; 695 if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) 696 return failure(); 697 698 // Utility functor that dispatches to `parseResourceGroup`, but implicitly 699 // provides most of the arguments. 700 auto parseGroup = [&](auto *handler, bool allowEmpty = false, 701 function_ref<LogicalResult(StringRef)> keyFn = {}) { 702 auto resolveKey = [&](StringRef key) -> StringRef { 703 auto it = dialectResourceHandleRenamingMap.find(key); 704 if (it == dialectResourceHandleRenamingMap.end()) 705 return ""; 706 return it->second; 707 }; 708 709 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, 710 stringReader, handler, bufferOwnerRef, resolveKey, 711 keyFn); 712 }; 713 714 // Read the external resources from the bytecode. 715 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { 716 StringRef key; 717 if (failed(stringReader.parseString(offsetReader, key))) 718 return failure(); 719 720 // Get the handler for these resources. 721 // TODO: Should we require handling external resources in some scenarios? 722 AsmResourceParser *handler = config.getResourceParser(key); 723 if (!handler) { 724 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key 725 << "'"; 726 } 727 728 if (failed(parseGroup(handler))) 729 return failure(); 730 } 731 732 // Read the dialect resources from the bytecode. 733 MLIRContext *ctx = fileLoc->getContext(); 734 while (!offsetReader.empty()) { 735 std::unique_ptr<BytecodeDialect> *dialect; 736 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || 737 failed((*dialect)->load(dialectReader, ctx))) 738 return failure(); 739 Dialect *loadedDialect = (*dialect)->getLoadedDialect(); 740 if (!loadedDialect) { 741 return resourceReader.emitError() 742 << "dialect '" << (*dialect)->name << "' is unknown"; 743 } 744 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); 745 if (!handler) { 746 return resourceReader.emitError() 747 << "unexpected resources for dialect '" << (*dialect)->name << "'"; 748 } 749 750 // Ensure that each resource is declared before being processed. 751 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { 752 FailureOr<AsmDialectResourceHandle> handle = 753 handler->declareResource(key); 754 if (failed(handle)) { 755 return resourceReader.emitError() 756 << "unknown 'resource' key '" << key << "' for dialect '" 757 << (*dialect)->name << "'"; 758 } 759 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); 760 dialectResources.push_back(*handle); 761 return success(); 762 }; 763 764 // Parse the resources for this dialect. We allow empty resources because we 765 // just treat these as declarations. 766 if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) 767 return failure(); 768 } 769 770 return success(); 771 } 772 773 //===----------------------------------------------------------------------===// 774 // Attribute/Type Reader 775 //===----------------------------------------------------------------------===// 776 777 namespace { 778 /// This class provides support for reading attribute and type entries from the 779 /// bytecode. Attribute and Type entries are read lazily on demand, so we use 780 /// this reader to manage when to actually parse them from the bytecode. 781 class AttrTypeReader { 782 /// This class represents a single attribute or type entry. 783 template <typename T> 784 struct Entry { 785 /// The entry, or null if it hasn't been resolved yet. 786 T entry = {}; 787 /// The parent dialect of this entry. 788 BytecodeDialect *dialect = nullptr; 789 /// A flag indicating if the entry was encoded using a custom encoding, 790 /// instead of using the textual assembly format. 791 bool hasCustomEncoding = false; 792 /// The raw data of this entry in the bytecode. 793 ArrayRef<uint8_t> data; 794 }; 795 using AttrEntry = Entry<Attribute>; 796 using TypeEntry = Entry<Type>; 797 798 public: 799 AttrTypeReader(StringSectionReader &stringReader, 800 ResourceSectionReader &resourceReader, 801 const llvm::StringMap<BytecodeDialect *> &dialectsMap, 802 uint64_t &bytecodeVersion, Location fileLoc, 803 const ParserConfig &config) 804 : stringReader(stringReader), resourceReader(resourceReader), 805 dialectsMap(dialectsMap), fileLoc(fileLoc), 806 bytecodeVersion(bytecodeVersion), parserConfig(config) {} 807 808 /// Initialize the attribute and type information within the reader. 809 LogicalResult 810 initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 811 ArrayRef<uint8_t> sectionData, 812 ArrayRef<uint8_t> offsetSectionData); 813 814 /// Resolve the attribute or type at the given index. Returns nullptr on 815 /// failure. 816 Attribute resolveAttribute(size_t index) { 817 return resolveEntry(attributes, index, "Attribute"); 818 } 819 Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } 820 821 /// Parse a reference to an attribute or type using the given reader. 822 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { 823 uint64_t attrIdx; 824 if (failed(reader.parseVarInt(attrIdx))) 825 return failure(); 826 result = resolveAttribute(attrIdx); 827 return success(!!result); 828 } 829 LogicalResult parseOptionalAttribute(EncodingReader &reader, 830 Attribute &result) { 831 uint64_t attrIdx; 832 bool flag; 833 if (failed(reader.parseVarIntWithFlag(attrIdx, flag))) 834 return failure(); 835 if (!flag) 836 return success(); 837 result = resolveAttribute(attrIdx); 838 return success(!!result); 839 } 840 841 LogicalResult parseType(EncodingReader &reader, Type &result) { 842 uint64_t typeIdx; 843 if (failed(reader.parseVarInt(typeIdx))) 844 return failure(); 845 result = resolveType(typeIdx); 846 return success(!!result); 847 } 848 849 template <typename T> 850 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 851 Attribute baseResult; 852 if (failed(parseAttribute(reader, baseResult))) 853 return failure(); 854 if ((result = dyn_cast<T>(baseResult))) 855 return success(); 856 return reader.emitError("expected attribute of type: ", 857 llvm::getTypeName<T>(), ", but got: ", baseResult); 858 } 859 860 private: 861 /// Resolve the given entry at `index`. 862 template <typename T> 863 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 864 StringRef entryType); 865 866 /// Parse an entry using the given reader that was encoded using the textual 867 /// assembly format. 868 template <typename T> 869 LogicalResult parseAsmEntry(T &result, EncodingReader &reader, 870 StringRef entryType); 871 872 /// Parse an entry using the given reader that was encoded using a custom 873 /// bytecode format. 874 template <typename T> 875 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, 876 StringRef entryType); 877 878 /// The string section reader used to resolve string references when parsing 879 /// custom encoded attribute/type entries. 880 StringSectionReader &stringReader; 881 882 /// The resource section reader used to resolve resource references when 883 /// parsing custom encoded attribute/type entries. 884 ResourceSectionReader &resourceReader; 885 886 /// The map of the loaded dialects used to retrieve dialect information, such 887 /// as the dialect version. 888 const llvm::StringMap<BytecodeDialect *> &dialectsMap; 889 890 /// The set of attribute and type entries. 891 SmallVector<AttrEntry> attributes; 892 SmallVector<TypeEntry> types; 893 894 /// A location used for error emission. 895 Location fileLoc; 896 897 /// Current bytecode version being used. 898 uint64_t &bytecodeVersion; 899 900 /// Reference to the parser configuration. 901 const ParserConfig &parserConfig; 902 }; 903 904 class DialectReader : public DialectBytecodeReader { 905 public: 906 DialectReader(AttrTypeReader &attrTypeReader, 907 StringSectionReader &stringReader, 908 ResourceSectionReader &resourceReader, 909 const llvm::StringMap<BytecodeDialect *> &dialectsMap, 910 EncodingReader &reader, uint64_t &bytecodeVersion) 911 : attrTypeReader(attrTypeReader), stringReader(stringReader), 912 resourceReader(resourceReader), dialectsMap(dialectsMap), 913 reader(reader), bytecodeVersion(bytecodeVersion) {} 914 915 InFlightDiagnostic emitError(const Twine &msg) const override { 916 return reader.emitError(msg); 917 } 918 919 FailureOr<const DialectVersion *> 920 getDialectVersion(StringRef dialectName) const override { 921 // First check if the dialect is available in the map. 922 auto dialectEntry = dialectsMap.find(dialectName); 923 if (dialectEntry == dialectsMap.end()) 924 return failure(); 925 // If the dialect was found, try to load it. This will trigger reading the 926 // bytecode version from the version buffer if it wasn't already processed. 927 // Return failure if either of those two actions could not be completed. 928 if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) || 929 dialectEntry->getValue()->loadedVersion.get() == nullptr) 930 return failure(); 931 return dialectEntry->getValue()->loadedVersion.get(); 932 } 933 934 MLIRContext *getContext() const override { return getLoc().getContext(); } 935 936 uint64_t getBytecodeVersion() const override { return bytecodeVersion; } 937 938 DialectReader withEncodingReader(EncodingReader &encReader) const { 939 return DialectReader(attrTypeReader, stringReader, resourceReader, 940 dialectsMap, encReader, bytecodeVersion); 941 } 942 943 Location getLoc() const { return reader.getLoc(); } 944 945 //===--------------------------------------------------------------------===// 946 // IR 947 //===--------------------------------------------------------------------===// 948 949 LogicalResult readAttribute(Attribute &result) override { 950 return attrTypeReader.parseAttribute(reader, result); 951 } 952 LogicalResult readOptionalAttribute(Attribute &result) override { 953 return attrTypeReader.parseOptionalAttribute(reader, result); 954 } 955 LogicalResult readType(Type &result) override { 956 return attrTypeReader.parseType(reader, result); 957 } 958 959 FailureOr<AsmDialectResourceHandle> readResourceHandle() override { 960 AsmDialectResourceHandle handle; 961 if (failed(resourceReader.parseResourceHandle(reader, handle))) 962 return failure(); 963 return handle; 964 } 965 966 //===--------------------------------------------------------------------===// 967 // Primitives 968 //===--------------------------------------------------------------------===// 969 970 LogicalResult readVarInt(uint64_t &result) override { 971 return reader.parseVarInt(result); 972 } 973 974 LogicalResult readSignedVarInt(int64_t &result) override { 975 uint64_t unsignedResult; 976 if (failed(reader.parseSignedVarInt(unsignedResult))) 977 return failure(); 978 result = static_cast<int64_t>(unsignedResult); 979 return success(); 980 } 981 982 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { 983 // Small values are encoded using a single byte. 984 if (bitWidth <= 8) { 985 uint8_t value; 986 if (failed(reader.parseByte(value))) 987 return failure(); 988 return APInt(bitWidth, value); 989 } 990 991 // Large values up to 64 bits are encoded using a single varint. 992 if (bitWidth <= 64) { 993 uint64_t value; 994 if (failed(reader.parseSignedVarInt(value))) 995 return failure(); 996 return APInt(bitWidth, value); 997 } 998 999 // Otherwise, for really big values we encode the array of active words in 1000 // the value. 1001 uint64_t numActiveWords; 1002 if (failed(reader.parseVarInt(numActiveWords))) 1003 return failure(); 1004 SmallVector<uint64_t, 4> words(numActiveWords); 1005 for (uint64_t i = 0; i < numActiveWords; ++i) 1006 if (failed(reader.parseSignedVarInt(words[i]))) 1007 return failure(); 1008 return APInt(bitWidth, words); 1009 } 1010 1011 FailureOr<APFloat> 1012 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { 1013 FailureOr<APInt> intVal = 1014 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); 1015 if (failed(intVal)) 1016 return failure(); 1017 return APFloat(semantics, *intVal); 1018 } 1019 1020 LogicalResult readString(StringRef &result) override { 1021 return stringReader.parseString(reader, result); 1022 } 1023 1024 LogicalResult readBlob(ArrayRef<char> &result) override { 1025 uint64_t dataSize; 1026 ArrayRef<uint8_t> data; 1027 if (failed(reader.parseVarInt(dataSize)) || 1028 failed(reader.parseBytes(dataSize, data))) 1029 return failure(); 1030 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), 1031 data.size()); 1032 return success(); 1033 } 1034 1035 LogicalResult readBool(bool &result) override { 1036 return reader.parseByte(result); 1037 } 1038 1039 private: 1040 AttrTypeReader &attrTypeReader; 1041 StringSectionReader &stringReader; 1042 ResourceSectionReader &resourceReader; 1043 const llvm::StringMap<BytecodeDialect *> &dialectsMap; 1044 EncodingReader &reader; 1045 uint64_t &bytecodeVersion; 1046 }; 1047 1048 /// Wraps the properties section and handles reading properties out of it. 1049 class PropertiesSectionReader { 1050 public: 1051 /// Initialize the properties section reader with the given section data. 1052 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { 1053 if (sectionData.empty()) 1054 return success(); 1055 EncodingReader propReader(sectionData, fileLoc); 1056 uint64_t count; 1057 if (failed(propReader.parseVarInt(count))) 1058 return failure(); 1059 // Parse the raw properties buffer. 1060 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) 1061 return failure(); 1062 1063 EncodingReader offsetsReader(propertiesBuffers, fileLoc); 1064 offsetTable.reserve(count); 1065 for (auto idx : llvm::seq<int64_t>(0, count)) { 1066 (void)idx; 1067 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); 1068 ArrayRef<uint8_t> rawProperties; 1069 uint64_t dataSize; 1070 if (failed(offsetsReader.parseVarInt(dataSize)) || 1071 failed(offsetsReader.parseBytes(dataSize, rawProperties))) 1072 return failure(); 1073 } 1074 if (!offsetsReader.empty()) 1075 return offsetsReader.emitError() 1076 << "Broken properties section: didn't exhaust the offsets table"; 1077 return success(); 1078 } 1079 1080 LogicalResult read(Location fileLoc, DialectReader &dialectReader, 1081 OperationName *opName, OperationState &opState) { 1082 uint64_t propertiesIdx; 1083 if (failed(dialectReader.readVarInt(propertiesIdx))) 1084 return failure(); 1085 if (propertiesIdx >= offsetTable.size()) 1086 return dialectReader.emitError("Properties idx out-of-bound for ") 1087 << opName->getStringRef(); 1088 size_t propertiesOffset = offsetTable[propertiesIdx]; 1089 if (propertiesIdx >= propertiesBuffers.size()) 1090 return dialectReader.emitError("Properties offset out-of-bound for ") 1091 << opName->getStringRef(); 1092 1093 // Acquire the sub-buffer that represent the requested properties. 1094 ArrayRef<char> rawProperties; 1095 { 1096 // "Seek" to the requested offset by getting a new reader with the right 1097 // sub-buffer. 1098 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), 1099 fileLoc); 1100 // Properties are stored as a sequence of {size + raw_data}. 1101 if (failed( 1102 dialectReader.withEncodingReader(reader).readBlob(rawProperties))) 1103 return failure(); 1104 } 1105 // Setup a new reader to read from the `rawProperties` sub-buffer. 1106 EncodingReader reader( 1107 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); 1108 DialectReader propReader = dialectReader.withEncodingReader(reader); 1109 1110 auto *iface = opName->getInterface<BytecodeOpInterface>(); 1111 if (iface) 1112 return iface->readProperties(propReader, opState); 1113 if (opName->isRegistered()) 1114 return propReader.emitError( 1115 "has properties but missing BytecodeOpInterface for ") 1116 << opName->getStringRef(); 1117 // Unregistered op are storing properties as an attribute. 1118 return propReader.readAttribute(opState.propertiesAttr); 1119 } 1120 1121 private: 1122 /// The properties buffer referenced within the bytecode file. 1123 ArrayRef<uint8_t> propertiesBuffers; 1124 1125 /// Table of offset in the buffer above. 1126 SmallVector<int64_t> offsetTable; 1127 }; 1128 } // namespace 1129 1130 LogicalResult AttrTypeReader::initialize( 1131 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 1132 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) { 1133 EncodingReader offsetReader(offsetSectionData, fileLoc); 1134 1135 // Parse the number of attribute and type entries. 1136 uint64_t numAttributes, numTypes; 1137 if (failed(offsetReader.parseVarInt(numAttributes)) || 1138 failed(offsetReader.parseVarInt(numTypes))) 1139 return failure(); 1140 attributes.resize(numAttributes); 1141 types.resize(numTypes); 1142 1143 // A functor used to accumulate the offsets for the entries in the given 1144 // range. 1145 uint64_t currentOffset = 0; 1146 auto parseEntries = [&](auto &&range) { 1147 size_t currentIndex = 0, endIndex = range.size(); 1148 1149 // Parse an individual entry. 1150 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { 1151 auto &entry = range[currentIndex++]; 1152 1153 uint64_t entrySize; 1154 if (failed(offsetReader.parseVarIntWithFlag(entrySize, 1155 entry.hasCustomEncoding))) 1156 return failure(); 1157 1158 // Verify that the offset is actually valid. 1159 if (currentOffset + entrySize > sectionData.size()) { 1160 return offsetReader.emitError( 1161 "Attribute or Type entry offset points past the end of section"); 1162 } 1163 1164 entry.data = sectionData.slice(currentOffset, entrySize); 1165 entry.dialect = dialect; 1166 currentOffset += entrySize; 1167 return success(); 1168 }; 1169 while (currentIndex != endIndex) 1170 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) 1171 return failure(); 1172 return success(); 1173 }; 1174 1175 // Process each of the attributes, and then the types. 1176 if (failed(parseEntries(attributes)) || failed(parseEntries(types))) 1177 return failure(); 1178 1179 // Ensure that we read everything from the section. 1180 if (!offsetReader.empty()) { 1181 return offsetReader.emitError( 1182 "unexpected trailing data in the Attribute/Type offset section"); 1183 } 1184 1185 return success(); 1186 } 1187 1188 template <typename T> 1189 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 1190 StringRef entryType) { 1191 if (index >= entries.size()) { 1192 emitError(fileLoc) << "invalid " << entryType << " index: " << index; 1193 return {}; 1194 } 1195 1196 // If the entry has already been resolved, there is nothing left to do. 1197 Entry<T> &entry = entries[index]; 1198 if (entry.entry) 1199 return entry.entry; 1200 1201 // Parse the entry. 1202 EncodingReader reader(entry.data, fileLoc); 1203 1204 // Parse based on how the entry was encoded. 1205 if (entry.hasCustomEncoding) { 1206 if (failed(parseCustomEntry(entry, reader, entryType))) 1207 return T(); 1208 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { 1209 return T(); 1210 } 1211 1212 if (!reader.empty()) { 1213 reader.emitError("unexpected trailing bytes after " + entryType + " entry"); 1214 return T(); 1215 } 1216 return entry.entry; 1217 } 1218 1219 template <typename T> 1220 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, 1221 StringRef entryType) { 1222 StringRef asmStr; 1223 if (failed(reader.parseNullTerminatedString(asmStr))) 1224 return failure(); 1225 1226 // Invoke the MLIR assembly parser to parse the entry text. 1227 size_t numRead = 0; 1228 MLIRContext *context = fileLoc->getContext(); 1229 if constexpr (std::is_same_v<T, Type>) 1230 result = 1231 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); 1232 else 1233 result = ::parseAttribute(asmStr, context, Type(), &numRead, 1234 /*isKnownNullTerminated=*/true); 1235 if (!result) 1236 return failure(); 1237 1238 // Ensure there weren't dangling characters after the entry. 1239 if (numRead != asmStr.size()) { 1240 return reader.emitError("trailing characters found after ", entryType, 1241 " assembly format: ", asmStr.drop_front(numRead)); 1242 } 1243 return success(); 1244 } 1245 1246 template <typename T> 1247 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, 1248 EncodingReader &reader, 1249 StringRef entryType) { 1250 DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, 1251 reader, bytecodeVersion); 1252 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) 1253 return failure(); 1254 1255 if constexpr (std::is_same_v<T, Type>) { 1256 // Try parsing with callbacks first if available. 1257 for (const auto &callback : 1258 parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) { 1259 if (failed( 1260 callback->read(dialectReader, entry.dialect->name, entry.entry))) 1261 return failure(); 1262 // Early return if parsing was successful. 1263 if (!!entry.entry) 1264 return success(); 1265 1266 // Reset the reader if we failed to parse, so we can fall through the 1267 // other parsing functions. 1268 reader = EncodingReader(entry.data, reader.getLoc()); 1269 } 1270 } else { 1271 // Try parsing with callbacks first if available. 1272 for (const auto &callback : 1273 parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) { 1274 if (failed( 1275 callback->read(dialectReader, entry.dialect->name, entry.entry))) 1276 return failure(); 1277 // Early return if parsing was successful. 1278 if (!!entry.entry) 1279 return success(); 1280 1281 // Reset the reader if we failed to parse, so we can fall through the 1282 // other parsing functions. 1283 reader = EncodingReader(entry.data, reader.getLoc()); 1284 } 1285 } 1286 1287 // Ensure that the dialect implements the bytecode interface. 1288 if (!entry.dialect->interface) { 1289 return reader.emitError("dialect '", entry.dialect->name, 1290 "' does not implement the bytecode interface"); 1291 } 1292 1293 if constexpr (std::is_same_v<T, Type>) 1294 entry.entry = entry.dialect->interface->readType(dialectReader); 1295 else 1296 entry.entry = entry.dialect->interface->readAttribute(dialectReader); 1297 1298 return success(!!entry.entry); 1299 } 1300 1301 //===----------------------------------------------------------------------===// 1302 // Bytecode Reader 1303 //===----------------------------------------------------------------------===// 1304 1305 /// This class is used to read a bytecode buffer and translate it into MLIR. 1306 class mlir::BytecodeReader::Impl { 1307 struct RegionReadState; 1308 using LazyLoadableOpsInfo = 1309 std::list<std::pair<Operation *, RegionReadState>>; 1310 using LazyLoadableOpsMap = 1311 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; 1312 1313 public: 1314 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, 1315 llvm::MemoryBufferRef buffer, 1316 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 1317 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), 1318 attrTypeReader(stringReader, resourceReader, dialectsMap, version, 1319 fileLoc, config), 1320 // Use the builtin unrealized conversion cast operation to represent 1321 // forward references to values that aren't yet defined. 1322 forwardRefOpState(UnknownLoc::get(config.getContext()), 1323 "builtin.unrealized_conversion_cast", ValueRange(), 1324 NoneType::get(config.getContext())), 1325 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} 1326 1327 /// Read the bytecode defined within `buffer` into the given block. 1328 LogicalResult read(Block *block, 1329 llvm::function_ref<bool(Operation *)> lazyOps); 1330 1331 /// Return the number of ops that haven't been materialized yet. 1332 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } 1333 1334 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } 1335 1336 /// Materialize the provided operation, invoke the lazyOpsCallback on every 1337 /// newly found lazy operation. 1338 LogicalResult 1339 materialize(Operation *op, 1340 llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1341 this->lazyOpsCallback = lazyOpsCallback; 1342 auto resetlazyOpsCallback = 1343 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1344 auto it = lazyLoadableOpsMap.find(op); 1345 assert(it != lazyLoadableOpsMap.end() && 1346 "materialize called on non-materializable op"); 1347 return materialize(it); 1348 } 1349 1350 /// Materialize all operations. 1351 LogicalResult materializeAll() { 1352 while (!lazyLoadableOpsMap.empty()) { 1353 if (failed(materialize(lazyLoadableOpsMap.begin()))) 1354 return failure(); 1355 } 1356 return success(); 1357 } 1358 1359 /// Finalize the lazy-loading by calling back with every op that hasn't been 1360 /// materialized to let the client decide if the op should be deleted or 1361 /// materialized. The op is materialized if the callback returns true, deleted 1362 /// otherwise. 1363 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { 1364 while (!lazyLoadableOps.empty()) { 1365 Operation *op = lazyLoadableOps.begin()->first; 1366 if (shouldMaterialize(op)) { 1367 if (failed(materialize(lazyLoadableOpsMap.find(op)))) 1368 return failure(); 1369 continue; 1370 } 1371 op->dropAllReferences(); 1372 op->erase(); 1373 lazyLoadableOps.pop_front(); 1374 lazyLoadableOpsMap.erase(op); 1375 } 1376 return success(); 1377 } 1378 1379 private: 1380 LogicalResult materialize(LazyLoadableOpsMap::iterator it) { 1381 assert(it != lazyLoadableOpsMap.end() && 1382 "materialize called on non-materializable op"); 1383 valueScopes.emplace_back(); 1384 std::vector<RegionReadState> regionStack; 1385 regionStack.push_back(std::move(it->getSecond()->second)); 1386 lazyLoadableOps.erase(it->getSecond()); 1387 lazyLoadableOpsMap.erase(it); 1388 1389 while (!regionStack.empty()) 1390 if (failed(parseRegions(regionStack, regionStack.back()))) 1391 return failure(); 1392 return success(); 1393 } 1394 1395 /// Return the context for this config. 1396 MLIRContext *getContext() const { return config.getContext(); } 1397 1398 /// Parse the bytecode version. 1399 LogicalResult parseVersion(EncodingReader &reader); 1400 1401 //===--------------------------------------------------------------------===// 1402 // Dialect Section 1403 1404 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); 1405 1406 /// Parse an operation name reference using the given reader, and set the 1407 /// `wasRegistered` flag that indicates if the bytecode was produced by a 1408 /// context where opName was registered. 1409 FailureOr<OperationName> parseOpName(EncodingReader &reader, 1410 std::optional<bool> &wasRegistered); 1411 1412 //===--------------------------------------------------------------------===// 1413 // Attribute/Type Section 1414 1415 /// Parse an attribute or type using the given reader. 1416 template <typename T> 1417 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 1418 return attrTypeReader.parseAttribute(reader, result); 1419 } 1420 LogicalResult parseType(EncodingReader &reader, Type &result) { 1421 return attrTypeReader.parseType(reader, result); 1422 } 1423 1424 //===--------------------------------------------------------------------===// 1425 // Resource Section 1426 1427 LogicalResult 1428 parseResourceSection(EncodingReader &reader, 1429 std::optional<ArrayRef<uint8_t>> resourceData, 1430 std::optional<ArrayRef<uint8_t>> resourceOffsetData); 1431 1432 //===--------------------------------------------------------------------===// 1433 // IR Section 1434 1435 /// This struct represents the current read state of a range of regions. This 1436 /// struct is used to enable iterative parsing of regions. 1437 struct RegionReadState { 1438 RegionReadState(Operation *op, EncodingReader *reader, 1439 bool isIsolatedFromAbove) 1440 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} 1441 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, 1442 bool isIsolatedFromAbove) 1443 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), 1444 isIsolatedFromAbove(isIsolatedFromAbove) {} 1445 1446 /// The current regions being read. 1447 MutableArrayRef<Region>::iterator curRegion, endRegion; 1448 /// This is the reader to use for this region, this pointer is pointing to 1449 /// the parent region reader unless the current region is IsolatedFromAbove, 1450 /// in which case the pointer is pointing to the `owningReader` which is a 1451 /// section dedicated to the current region. 1452 EncodingReader *reader; 1453 std::unique_ptr<EncodingReader> owningReader; 1454 1455 /// The number of values defined immediately within this region. 1456 unsigned numValues = 0; 1457 1458 /// The current blocks of the region being read. 1459 SmallVector<Block *> curBlocks; 1460 Region::iterator curBlock = {}; 1461 1462 /// The number of operations remaining to be read from the current block 1463 /// being read. 1464 uint64_t numOpsRemaining = 0; 1465 1466 /// A flag indicating if the regions being read are isolated from above. 1467 bool isIsolatedFromAbove = false; 1468 }; 1469 1470 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); 1471 LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, 1472 RegionReadState &readState); 1473 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, 1474 RegionReadState &readState, 1475 bool &isIsolatedFromAbove); 1476 1477 LogicalResult parseRegion(RegionReadState &readState); 1478 LogicalResult parseBlockHeader(EncodingReader &reader, 1479 RegionReadState &readState); 1480 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); 1481 1482 //===--------------------------------------------------------------------===// 1483 // Value Processing 1484 1485 /// Parse an operand reference using the given reader. Returns nullptr in the 1486 /// case of failure. 1487 Value parseOperand(EncodingReader &reader); 1488 1489 /// Sequentially define the given value range. 1490 LogicalResult defineValues(EncodingReader &reader, ValueRange values); 1491 1492 /// Create a value to use for a forward reference. 1493 Value createForwardRef(); 1494 1495 //===--------------------------------------------------------------------===// 1496 // Use-list order helpers 1497 1498 /// This struct is a simple storage that contains information required to 1499 /// reorder the use-list of a value with respect to the pre-order traversal 1500 /// ordering. 1501 struct UseListOrderStorage { 1502 UseListOrderStorage(bool isIndexPairEncoding, 1503 SmallVector<unsigned, 4> &&indices) 1504 : indices(std::move(indices)), 1505 isIndexPairEncoding(isIndexPairEncoding){}; 1506 /// The vector containing the information required to reorder the 1507 /// use-list of a value. 1508 SmallVector<unsigned, 4> indices; 1509 1510 /// Whether indices represent a pair of type `(src, dst)` or it is a direct 1511 /// indexing, such as `dst = order[src]`. 1512 bool isIndexPairEncoding; 1513 }; 1514 1515 /// Parse use-list order from bytecode for a range of values if available. The 1516 /// range is expected to be either a block argument or an op result range. On 1517 /// success, return a map of the position in the range and the use-list order 1518 /// encoding. The function assumes to know the size of the range it is 1519 /// processing. 1520 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; 1521 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, 1522 uint64_t rangeSize); 1523 1524 /// Shuffle the use-chain according to the order parsed. 1525 LogicalResult sortUseListOrder(Value value); 1526 1527 /// Recursively visit all the values defined within topLevelOp and sort the 1528 /// use-list orders according to the indices parsed. 1529 LogicalResult processUseLists(Operation *topLevelOp); 1530 1531 //===--------------------------------------------------------------------===// 1532 // Fields 1533 1534 /// This class represents a single value scope, in which a value scope is 1535 /// delimited by isolated from above regions. 1536 struct ValueScope { 1537 /// Push a new region state onto this scope, reserving enough values for 1538 /// those defined within the current region of the provided state. 1539 void push(RegionReadState &readState) { 1540 nextValueIDs.push_back(values.size()); 1541 values.resize(values.size() + readState.numValues); 1542 } 1543 1544 /// Pop the values defined for the current region within the provided region 1545 /// state. 1546 void pop(RegionReadState &readState) { 1547 values.resize(values.size() - readState.numValues); 1548 nextValueIDs.pop_back(); 1549 } 1550 1551 /// The set of values defined in this scope. 1552 std::vector<Value> values; 1553 1554 /// The ID for the next defined value for each region current being 1555 /// processed in this scope. 1556 SmallVector<unsigned, 4> nextValueIDs; 1557 }; 1558 1559 /// The configuration of the parser. 1560 const ParserConfig &config; 1561 1562 /// A location to use when emitting errors. 1563 Location fileLoc; 1564 1565 /// Flag that indicates if lazyloading is enabled. 1566 bool lazyLoading; 1567 1568 /// Keep track of operations that have been lazy loaded (their regions haven't 1569 /// been materialized), along with the `RegionReadState` that allows to 1570 /// lazy-load the regions nested under the operation. 1571 LazyLoadableOpsInfo lazyLoadableOps; 1572 LazyLoadableOpsMap lazyLoadableOpsMap; 1573 llvm::function_ref<bool(Operation *)> lazyOpsCallback; 1574 1575 /// The reader used to process attribute and types within the bytecode. 1576 AttrTypeReader attrTypeReader; 1577 1578 /// The version of the bytecode being read. 1579 uint64_t version = 0; 1580 1581 /// The producer of the bytecode being read. 1582 StringRef producer; 1583 1584 /// The table of IR units referenced within the bytecode file. 1585 SmallVector<std::unique_ptr<BytecodeDialect>> dialects; 1586 llvm::StringMap<BytecodeDialect *> dialectsMap; 1587 SmallVector<BytecodeOperationName> opNames; 1588 1589 /// The reader used to process resources within the bytecode. 1590 ResourceSectionReader resourceReader; 1591 1592 /// Worklist of values with custom use-list orders to process before the end 1593 /// of the parsing. 1594 DenseMap<void *, UseListOrderStorage> valueToUseListMap; 1595 1596 /// The table of strings referenced within the bytecode file. 1597 StringSectionReader stringReader; 1598 1599 /// The table of properties referenced by the operation in the bytecode file. 1600 PropertiesSectionReader propertiesReader; 1601 1602 /// The current set of available IR value scopes. 1603 std::vector<ValueScope> valueScopes; 1604 1605 /// The global pre-order operation ordering. 1606 DenseMap<Operation *, unsigned> operationIDs; 1607 1608 /// A block containing the set of operations defined to create forward 1609 /// references. 1610 Block forwardRefOps; 1611 1612 /// A block containing previously created, and no longer used, forward 1613 /// reference operations. 1614 Block openForwardRefOps; 1615 1616 /// An operation state used when instantiating forward references. 1617 OperationState forwardRefOpState; 1618 1619 /// Reference to the input buffer. 1620 llvm::MemoryBufferRef buffer; 1621 1622 /// The optional owning source manager, which when present may be used to 1623 /// extend the lifetime of the input buffer. 1624 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 1625 }; 1626 1627 LogicalResult BytecodeReader::Impl::read( 1628 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1629 EncodingReader reader(buffer.getBuffer(), fileLoc); 1630 this->lazyOpsCallback = lazyOpsCallback; 1631 auto resetlazyOpsCallback = 1632 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1633 1634 // Skip over the bytecode header, this should have already been checked. 1635 if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) 1636 return failure(); 1637 // Parse the bytecode version and producer. 1638 if (failed(parseVersion(reader)) || 1639 failed(reader.parseNullTerminatedString(producer))) 1640 return failure(); 1641 1642 // Add a diagnostic handler that attaches a note that includes the original 1643 // producer of the bytecode. 1644 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { 1645 diag.attachNote() << "in bytecode version " << version 1646 << " produced by: " << producer; 1647 return failure(); 1648 }); 1649 1650 // Parse the raw data for each of the top-level sections of the bytecode. 1651 std::optional<ArrayRef<uint8_t>> 1652 sectionDatas[bytecode::Section::kNumSections]; 1653 while (!reader.empty()) { 1654 // Read the next section from the bytecode. 1655 bytecode::Section::ID sectionID; 1656 ArrayRef<uint8_t> sectionData; 1657 if (failed(reader.parseSection(sectionID, sectionData))) 1658 return failure(); 1659 1660 // Check for duplicate sections, we only expect one instance of each. 1661 if (sectionDatas[sectionID]) { 1662 return reader.emitError("duplicate top-level section: ", 1663 ::toString(sectionID)); 1664 } 1665 sectionDatas[sectionID] = sectionData; 1666 } 1667 // Check that all of the required sections were found. 1668 for (int i = 0; i < bytecode::Section::kNumSections; ++i) { 1669 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); 1670 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { 1671 return reader.emitError("missing data for top-level section: ", 1672 ::toString(sectionID)); 1673 } 1674 } 1675 1676 // Process the string section first. 1677 if (failed(stringReader.initialize( 1678 fileLoc, *sectionDatas[bytecode::Section::kString]))) 1679 return failure(); 1680 1681 // Process the properties section. 1682 if (sectionDatas[bytecode::Section::kProperties] && 1683 failed(propertiesReader.initialize( 1684 fileLoc, *sectionDatas[bytecode::Section::kProperties]))) 1685 return failure(); 1686 1687 // Process the dialect section. 1688 if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) 1689 return failure(); 1690 1691 // Process the resource section if present. 1692 if (failed(parseResourceSection( 1693 reader, sectionDatas[bytecode::Section::kResource], 1694 sectionDatas[bytecode::Section::kResourceOffset]))) 1695 return failure(); 1696 1697 // Process the attribute and type section. 1698 if (failed(attrTypeReader.initialize( 1699 dialects, *sectionDatas[bytecode::Section::kAttrType], 1700 *sectionDatas[bytecode::Section::kAttrTypeOffset]))) 1701 return failure(); 1702 1703 // Finally, process the IR section. 1704 return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); 1705 } 1706 1707 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { 1708 if (failed(reader.parseVarInt(version))) 1709 return failure(); 1710 1711 // Validate the bytecode version. 1712 uint64_t currentVersion = bytecode::kVersion; 1713 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; 1714 if (version < minSupportedVersion) { 1715 return reader.emitError("bytecode version ", version, 1716 " is older than the current version of ", 1717 currentVersion, ", and upgrade is not supported"); 1718 } 1719 if (version > currentVersion) { 1720 return reader.emitError("bytecode version ", version, 1721 " is newer than the current version ", 1722 currentVersion); 1723 } 1724 // Override any request to lazy-load if the bytecode version is too old. 1725 if (version < bytecode::kLazyLoading) 1726 lazyLoading = false; 1727 return success(); 1728 } 1729 1730 //===----------------------------------------------------------------------===// 1731 // Dialect Section 1732 1733 LogicalResult BytecodeDialect::load(const DialectReader &reader, 1734 MLIRContext *ctx) { 1735 if (dialect) 1736 return success(); 1737 Dialect *loadedDialect = ctx->getOrLoadDialect(name); 1738 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { 1739 return reader.emitError("dialect '") 1740 << name 1741 << "' is unknown. If this is intended, please call " 1742 "allowUnregisteredDialects() on the MLIRContext, or use " 1743 "-allow-unregistered-dialect with the MLIR tool used."; 1744 } 1745 dialect = loadedDialect; 1746 1747 // If the dialect was actually loaded, check to see if it has a bytecode 1748 // interface. 1749 if (loadedDialect) 1750 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); 1751 if (!versionBuffer.empty()) { 1752 if (!interface) 1753 return reader.emitError("dialect '") 1754 << name 1755 << "' does not implement the bytecode interface, " 1756 "but found a version entry"; 1757 EncodingReader encReader(versionBuffer, reader.getLoc()); 1758 DialectReader versionReader = reader.withEncodingReader(encReader); 1759 loadedVersion = interface->readVersion(versionReader); 1760 if (!loadedVersion) 1761 return failure(); 1762 } 1763 return success(); 1764 } 1765 1766 LogicalResult 1767 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { 1768 EncodingReader sectionReader(sectionData, fileLoc); 1769 1770 // Parse the number of dialects in the section. 1771 uint64_t numDialects; 1772 if (failed(sectionReader.parseVarInt(numDialects))) 1773 return failure(); 1774 dialects.resize(numDialects); 1775 1776 // Parse each of the dialects. 1777 for (uint64_t i = 0; i < numDialects; ++i) { 1778 dialects[i] = std::make_unique<BytecodeDialect>(); 1779 /// Before version kDialectVersioning, there wasn't any versioning available 1780 /// for dialects, and the entryIdx represent the string itself. 1781 if (version < bytecode::kDialectVersioning) { 1782 if (failed(stringReader.parseString(sectionReader, dialects[i]->name))) 1783 return failure(); 1784 continue; 1785 } 1786 1787 // Parse ID representing dialect and version. 1788 uint64_t dialectNameIdx; 1789 bool versionAvailable; 1790 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, 1791 versionAvailable))) 1792 return failure(); 1793 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, 1794 dialects[i]->name))) 1795 return failure(); 1796 if (versionAvailable) { 1797 bytecode::Section::ID sectionID; 1798 if (failed(sectionReader.parseSection(sectionID, 1799 dialects[i]->versionBuffer))) 1800 return failure(); 1801 if (sectionID != bytecode::Section::kDialectVersions) { 1802 emitError(fileLoc, "expected dialect version section"); 1803 return failure(); 1804 } 1805 } 1806 dialectsMap[dialects[i]->name] = dialects[i].get(); 1807 } 1808 1809 // Parse the operation names, which are grouped by dialect. 1810 auto parseOpName = [&](BytecodeDialect *dialect) { 1811 StringRef opName; 1812 std::optional<bool> wasRegistered; 1813 // Prior to version kNativePropertiesEncoding, the information about wheter 1814 // an op was registered or not wasn't encoded. 1815 if (version < bytecode::kNativePropertiesEncoding) { 1816 if (failed(stringReader.parseString(sectionReader, opName))) 1817 return failure(); 1818 } else { 1819 bool wasRegisteredFlag; 1820 if (failed(stringReader.parseStringWithFlag(sectionReader, opName, 1821 wasRegisteredFlag))) 1822 return failure(); 1823 wasRegistered = wasRegisteredFlag; 1824 } 1825 opNames.emplace_back(dialect, opName, wasRegistered); 1826 return success(); 1827 }; 1828 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation 1829 // where the number of ops are known. 1830 if (version >= bytecode::kElideUnknownBlockArgLocation) { 1831 uint64_t numOps; 1832 if (failed(sectionReader.parseVarInt(numOps))) 1833 return failure(); 1834 opNames.reserve(numOps); 1835 } 1836 while (!sectionReader.empty()) 1837 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) 1838 return failure(); 1839 return success(); 1840 } 1841 1842 FailureOr<OperationName> 1843 BytecodeReader::Impl::parseOpName(EncodingReader &reader, 1844 std::optional<bool> &wasRegistered) { 1845 BytecodeOperationName *opName = nullptr; 1846 if (failed(parseEntry(reader, opNames, opName, "operation name"))) 1847 return failure(); 1848 wasRegistered = opName->wasRegistered; 1849 // Check to see if this operation name has already been resolved. If we 1850 // haven't, load the dialect and build the operation name. 1851 if (!opName->opName) { 1852 // Load the dialect and its version. 1853 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1854 dialectsMap, reader, version); 1855 if (failed(opName->dialect->load(dialectReader, getContext()))) 1856 return failure(); 1857 // If the opName is empty, this is because we use to accept names such as 1858 // `foo` without any `.` separator. We shouldn't tolerate this in textual 1859 // format anymore but for now we'll be backward compatible. This can only 1860 // happen with unregistered dialects. 1861 if (opName->name.empty()) { 1862 if (opName->dialect->getLoadedDialect()) 1863 return emitError(fileLoc) << "has an empty opname for dialect '" 1864 << opName->dialect->name << "'\n"; 1865 1866 opName->opName.emplace(opName->dialect->name, getContext()); 1867 } else { 1868 opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), 1869 getContext()); 1870 } 1871 } 1872 return *opName->opName; 1873 } 1874 1875 //===----------------------------------------------------------------------===// 1876 // Resource Section 1877 1878 LogicalResult BytecodeReader::Impl::parseResourceSection( 1879 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, 1880 std::optional<ArrayRef<uint8_t>> resourceOffsetData) { 1881 // Ensure both sections are either present or not. 1882 if (resourceData.has_value() != resourceOffsetData.has_value()) { 1883 if (resourceOffsetData) 1884 return emitError(fileLoc, "unexpected resource offset section when " 1885 "resource section is not present"); 1886 return emitError( 1887 fileLoc, 1888 "expected resource offset section when resource section is present"); 1889 } 1890 1891 // If the resource sections are absent, there is nothing to do. 1892 if (!resourceData) 1893 return success(); 1894 1895 // Initialize the resource reader with the resource sections. 1896 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1897 dialectsMap, reader, version); 1898 return resourceReader.initialize(fileLoc, config, dialects, stringReader, 1899 *resourceData, *resourceOffsetData, 1900 dialectReader, bufferOwnerRef); 1901 } 1902 1903 //===----------------------------------------------------------------------===// 1904 // UseListOrder Helpers 1905 1906 FailureOr<BytecodeReader::Impl::UseListMapT> 1907 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, 1908 uint64_t numResults) { 1909 BytecodeReader::Impl::UseListMapT map; 1910 uint64_t numValuesToRead = 1; 1911 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) 1912 return failure(); 1913 1914 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { 1915 uint64_t resultIdx = 0; 1916 if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) 1917 return failure(); 1918 1919 uint64_t numValues; 1920 bool indexPairEncoding; 1921 if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) 1922 return failure(); 1923 1924 SmallVector<unsigned, 4> useListOrders; 1925 for (size_t idx = 0; idx < numValues; idx++) { 1926 uint64_t index; 1927 if (failed(reader.parseVarInt(index))) 1928 return failure(); 1929 useListOrders.push_back(index); 1930 } 1931 1932 // Store in a map the result index 1933 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, 1934 std::move(useListOrders))); 1935 } 1936 1937 return map; 1938 } 1939 1940 /// Sorts each use according to the order specified in the use-list parsed. If 1941 /// the custom use-list is not found, this means that the order needs to be 1942 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie 1943 /// on the same operation, the order will follow the reverse operand number 1944 /// ordering. 1945 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { 1946 // Early return for trivial use-lists. 1947 if (value.use_empty() || value.hasOneUse()) 1948 return success(); 1949 1950 bool hasIncomingOrder = 1951 valueToUseListMap.contains(value.getAsOpaquePointer()); 1952 1953 // Compute the current order of the use-list with respect to the global 1954 // ordering. Detect if the order is already sorted while doing so. 1955 bool alreadySorted = true; 1956 auto &firstUse = *value.use_begin(); 1957 uint64_t prevID = 1958 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); 1959 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; 1960 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1961 uint64_t currentID = bytecode::getUseID( 1962 item.value(), operationIDs.at(item.value().getOwner())); 1963 alreadySorted &= prevID > currentID; 1964 currentOrder.push_back({item.index(), currentID}); 1965 prevID = currentID; 1966 } 1967 1968 // If the order is already sorted, and there wasn't a custom order to apply 1969 // from the bytecode file, we are done. 1970 if (alreadySorted && !hasIncomingOrder) 1971 return success(); 1972 1973 // If not already sorted, sort the indices of the current order by descending 1974 // useIDs. 1975 if (!alreadySorted) 1976 std::sort( 1977 currentOrder.begin(), currentOrder.end(), 1978 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1979 1980 if (!hasIncomingOrder) { 1981 // If the bytecode file did not contain any custom use-list order, it means 1982 // that the order was descending useID. Hence, shuffle by the first index 1983 // of the `currentOrder` pair. 1984 SmallVector<unsigned> shuffle = SmallVector<unsigned>( 1985 llvm::map_range(currentOrder, [&](auto item) { return item.first; })); 1986 value.shuffleUseList(shuffle); 1987 return success(); 1988 } 1989 1990 // Pull the custom order info from the map. 1991 UseListOrderStorage customOrder = 1992 valueToUseListMap.at(value.getAsOpaquePointer()); 1993 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); 1994 uint64_t numUses = 1995 std::distance(value.getUses().begin(), value.getUses().end()); 1996 1997 // If the encoding was a pair of indices `(src, dst)` for every permutation, 1998 // reconstruct the shuffle vector for every use. Initialize the shuffle vector 1999 // as identity, and then apply the mapping encoded in the indices. 2000 if (customOrder.isIndexPairEncoding) { 2001 // Return failure if the number of indices was not representing pairs. 2002 if (shuffle.size() & 1) 2003 return failure(); 2004 2005 SmallVector<unsigned, 4> newShuffle(numUses); 2006 size_t idx = 0; 2007 std::iota(newShuffle.begin(), newShuffle.end(), idx); 2008 for (idx = 0; idx < shuffle.size(); idx += 2) 2009 newShuffle[shuffle[idx]] = shuffle[idx + 1]; 2010 2011 shuffle = std::move(newShuffle); 2012 } 2013 2014 // Make sure that the indices represent a valid mapping. That is, the sum of 2015 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no 2016 // duplicates are allowed in the list. 2017 DenseSet<unsigned> set; 2018 uint64_t accumulator = 0; 2019 for (const auto &elem : shuffle) { 2020 if (set.contains(elem)) 2021 return failure(); 2022 accumulator += elem; 2023 set.insert(elem); 2024 } 2025 if (numUses != shuffle.size() || 2026 accumulator != (((numUses - 1) * numUses) >> 1)) 2027 return failure(); 2028 2029 // Apply the current ordering map onto the shuffle vector to get the final 2030 // use-list sorting indices before shuffling. 2031 shuffle = SmallVector<unsigned, 4>(llvm::map_range( 2032 currentOrder, [&](auto item) { return shuffle[item.first]; })); 2033 value.shuffleUseList(shuffle); 2034 return success(); 2035 } 2036 2037 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { 2038 // Precompute operation IDs according to the pre-order walk of the IR. We 2039 // can't do this while parsing since parseRegions ordering is not strictly 2040 // equal to the pre-order walk. 2041 unsigned operationID = 0; 2042 topLevelOp->walk<mlir::WalkOrder::PreOrder>( 2043 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); 2044 2045 auto blockWalk = topLevelOp->walk([this](Block *block) { 2046 for (auto arg : block->getArguments()) 2047 if (failed(sortUseListOrder(arg))) 2048 return WalkResult::interrupt(); 2049 return WalkResult::advance(); 2050 }); 2051 2052 auto resultWalk = topLevelOp->walk([this](Operation *op) { 2053 for (auto result : op->getResults()) 2054 if (failed(sortUseListOrder(result))) 2055 return WalkResult::interrupt(); 2056 return WalkResult::advance(); 2057 }); 2058 2059 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); 2060 } 2061 2062 //===----------------------------------------------------------------------===// 2063 // IR Section 2064 2065 LogicalResult 2066 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, 2067 Block *block) { 2068 EncodingReader reader(sectionData, fileLoc); 2069 2070 // A stack of operation regions currently being read from the bytecode. 2071 std::vector<RegionReadState> regionStack; 2072 2073 // Parse the top-level block using a temporary module operation. 2074 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); 2075 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); 2076 regionStack.back().curBlocks.push_back(moduleOp->getBody()); 2077 regionStack.back().curBlock = regionStack.back().curRegion->begin(); 2078 if (failed(parseBlockHeader(reader, regionStack.back()))) 2079 return failure(); 2080 valueScopes.emplace_back(); 2081 valueScopes.back().push(regionStack.back()); 2082 2083 // Iteratively parse regions until everything has been resolved. 2084 while (!regionStack.empty()) 2085 if (failed(parseRegions(regionStack, regionStack.back()))) 2086 return failure(); 2087 if (!forwardRefOps.empty()) { 2088 return reader.emitError( 2089 "not all forward unresolved forward operand references"); 2090 } 2091 2092 // Sort use-lists according to what specified in bytecode. 2093 if (failed(processUseLists(*moduleOp))) 2094 return reader.emitError( 2095 "parsed use-list orders were invalid and could not be applied"); 2096 2097 // Resolve dialect version. 2098 for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) { 2099 // Parsing is complete, give an opportunity to each dialect to visit the 2100 // IR and perform upgrades. 2101 if (!byteCodeDialect->loadedVersion) 2102 continue; 2103 if (byteCodeDialect->interface && 2104 failed(byteCodeDialect->interface->upgradeFromVersion( 2105 *moduleOp, *byteCodeDialect->loadedVersion))) 2106 return failure(); 2107 } 2108 2109 // Verify that the parsed operations are valid. 2110 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) 2111 return failure(); 2112 2113 // Splice the parsed operations over to the provided top-level block. 2114 auto &parsedOps = moduleOp->getBody()->getOperations(); 2115 auto &destOps = block->getOperations(); 2116 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); 2117 return success(); 2118 } 2119 2120 LogicalResult 2121 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, 2122 RegionReadState &readState) { 2123 // Process regions, blocks, and operations until the end or if a nested 2124 // region is encountered. In this case we push a new state in regionStack and 2125 // return, the processing of the current region will resume afterward. 2126 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { 2127 // If the current block hasn't been setup yet, parse the header for this 2128 // region. The current block is already setup when this function was 2129 // interrupted to recurse down in a nested region and we resume the current 2130 // block after processing the nested region. 2131 if (readState.curBlock == Region::iterator()) { 2132 if (failed(parseRegion(readState))) 2133 return failure(); 2134 2135 // If the region is empty, there is nothing to more to do. 2136 if (readState.curRegion->empty()) 2137 continue; 2138 } 2139 2140 // Parse the blocks within the region. 2141 EncodingReader &reader = *readState.reader; 2142 do { 2143 while (readState.numOpsRemaining--) { 2144 // Read in the next operation. We don't read its regions directly, we 2145 // handle those afterwards as necessary. 2146 bool isIsolatedFromAbove = false; 2147 FailureOr<Operation *> op = 2148 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); 2149 if (failed(op)) 2150 return failure(); 2151 2152 // If the op has regions, add it to the stack for processing and return: 2153 // we stop the processing of the current region and resume it after the 2154 // inner one is completed. Unless LazyLoading is activated in which case 2155 // nested region parsing is delayed. 2156 if ((*op)->getNumRegions()) { 2157 RegionReadState childState(*op, &reader, isIsolatedFromAbove); 2158 2159 // Isolated regions are encoded as a section in version 2 and above. 2160 if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { 2161 bytecode::Section::ID sectionID; 2162 ArrayRef<uint8_t> sectionData; 2163 if (failed(reader.parseSection(sectionID, sectionData))) 2164 return failure(); 2165 if (sectionID != bytecode::Section::kIR) 2166 return emitError(fileLoc, "expected IR section for region"); 2167 childState.owningReader = 2168 std::make_unique<EncodingReader>(sectionData, fileLoc); 2169 childState.reader = childState.owningReader.get(); 2170 2171 // If the user has a callback set, they have the opportunity to 2172 // control lazyloading as we go. 2173 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) { 2174 lazyLoadableOps.emplace_back(*op, std::move(childState)); 2175 lazyLoadableOpsMap.try_emplace(*op, 2176 std::prev(lazyLoadableOps.end())); 2177 continue; 2178 } 2179 } 2180 regionStack.push_back(std::move(childState)); 2181 2182 // If the op is isolated from above, push a new value scope. 2183 if (isIsolatedFromAbove) 2184 valueScopes.emplace_back(); 2185 return success(); 2186 } 2187 } 2188 2189 // Move to the next block of the region. 2190 if (++readState.curBlock == readState.curRegion->end()) 2191 break; 2192 if (failed(parseBlockHeader(reader, readState))) 2193 return failure(); 2194 } while (true); 2195 2196 // Reset the current block and any values reserved for this region. 2197 readState.curBlock = {}; 2198 valueScopes.back().pop(readState); 2199 } 2200 2201 // When the regions have been fully parsed, pop them off of the read stack. If 2202 // the regions were isolated from above, we also pop the last value scope. 2203 if (readState.isIsolatedFromAbove) { 2204 assert(!valueScopes.empty() && "Expect a valueScope after reading region"); 2205 valueScopes.pop_back(); 2206 } 2207 assert(!regionStack.empty() && "Expect a regionStack after reading region"); 2208 regionStack.pop_back(); 2209 return success(); 2210 } 2211 2212 FailureOr<Operation *> 2213 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, 2214 RegionReadState &readState, 2215 bool &isIsolatedFromAbove) { 2216 // Parse the name of the operation. 2217 std::optional<bool> wasRegistered; 2218 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); 2219 if (failed(opName)) 2220 return failure(); 2221 2222 // Parse the operation mask, which indicates which components of the operation 2223 // are present. 2224 uint8_t opMask; 2225 if (failed(reader.parseByte(opMask))) 2226 return failure(); 2227 2228 /// Parse the location. 2229 LocationAttr opLoc; 2230 if (failed(parseAttribute(reader, opLoc))) 2231 return failure(); 2232 2233 // With the location and name resolved, we can start building the operation 2234 // state. 2235 OperationState opState(opLoc, *opName); 2236 2237 // Parse the attributes of the operation. 2238 if (opMask & bytecode::OpEncodingMask::kHasAttrs) { 2239 DictionaryAttr dictAttr; 2240 if (failed(parseAttribute(reader, dictAttr))) 2241 return failure(); 2242 opState.attributes = dictAttr; 2243 } 2244 2245 if (opMask & bytecode::OpEncodingMask::kHasProperties) { 2246 // kHasProperties wasn't emitted in older bytecode, we should never get 2247 // there without also having the `wasRegistered` flag available. 2248 if (!wasRegistered) 2249 return emitError(fileLoc, 2250 "Unexpected missing `wasRegistered` opname flag at " 2251 "bytecode version ") 2252 << version << " with properties."; 2253 // When an operation is emitted without being registered, the properties are 2254 // stored as an attribute. Otherwise the op must implement the bytecode 2255 // interface and control the serialization. 2256 if (wasRegistered) { 2257 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 2258 dialectsMap, reader, version); 2259 if (failed( 2260 propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) 2261 return failure(); 2262 } else { 2263 // If the operation wasn't registered when it was emitted, the properties 2264 // was serialized as an attribute. 2265 if (failed(parseAttribute(reader, opState.propertiesAttr))) 2266 return failure(); 2267 } 2268 } 2269 2270 /// Parse the results of the operation. 2271 if (opMask & bytecode::OpEncodingMask::kHasResults) { 2272 uint64_t numResults; 2273 if (failed(reader.parseVarInt(numResults))) 2274 return failure(); 2275 opState.types.resize(numResults); 2276 for (int i = 0, e = numResults; i < e; ++i) 2277 if (failed(parseType(reader, opState.types[i]))) 2278 return failure(); 2279 } 2280 2281 /// Parse the operands of the operation. 2282 if (opMask & bytecode::OpEncodingMask::kHasOperands) { 2283 uint64_t numOperands; 2284 if (failed(reader.parseVarInt(numOperands))) 2285 return failure(); 2286 opState.operands.resize(numOperands); 2287 for (int i = 0, e = numOperands; i < e; ++i) 2288 if (!(opState.operands[i] = parseOperand(reader))) 2289 return failure(); 2290 } 2291 2292 /// Parse the successors of the operation. 2293 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { 2294 uint64_t numSuccs; 2295 if (failed(reader.parseVarInt(numSuccs))) 2296 return failure(); 2297 opState.successors.resize(numSuccs); 2298 for (int i = 0, e = numSuccs; i < e; ++i) { 2299 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], 2300 "successor"))) 2301 return failure(); 2302 } 2303 } 2304 2305 /// Parse the use-list orders for the results of the operation. Use-list 2306 /// orders are available since version 3 of the bytecode. 2307 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; 2308 if (version >= bytecode::kUseListOrdering && 2309 (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { 2310 size_t numResults = opState.types.size(); 2311 auto parseResult = parseUseListOrderForRange(reader, numResults); 2312 if (failed(parseResult)) 2313 return failure(); 2314 resultIdxToUseListMap = std::move(*parseResult); 2315 } 2316 2317 /// Parse the regions of the operation. 2318 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { 2319 uint64_t numRegions; 2320 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) 2321 return failure(); 2322 2323 opState.regions.reserve(numRegions); 2324 for (int i = 0, e = numRegions; i < e; ++i) 2325 opState.regions.push_back(std::make_unique<Region>()); 2326 } 2327 2328 // Create the operation at the back of the current block. 2329 Operation *op = Operation::create(opState); 2330 readState.curBlock->push_back(op); 2331 2332 // If the operation had results, update the value references. 2333 if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) 2334 return failure(); 2335 2336 /// Store a map for every value that received a custom use-list order from the 2337 /// bytecode file. 2338 if (resultIdxToUseListMap.has_value()) { 2339 for (size_t idx = 0; idx < op->getNumResults(); idx++) { 2340 if (resultIdxToUseListMap->contains(idx)) { 2341 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), 2342 resultIdxToUseListMap->at(idx)); 2343 } 2344 } 2345 } 2346 return op; 2347 } 2348 2349 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { 2350 EncodingReader &reader = *readState.reader; 2351 2352 // Parse the number of blocks in the region. 2353 uint64_t numBlocks; 2354 if (failed(reader.parseVarInt(numBlocks))) 2355 return failure(); 2356 2357 // If the region is empty, there is nothing else to do. 2358 if (numBlocks == 0) 2359 return success(); 2360 2361 // Parse the number of values defined in this region. 2362 uint64_t numValues; 2363 if (failed(reader.parseVarInt(numValues))) 2364 return failure(); 2365 readState.numValues = numValues; 2366 2367 // Create the blocks within this region. We do this before processing so that 2368 // we can rely on the blocks existing when creating operations. 2369 readState.curBlocks.clear(); 2370 readState.curBlocks.reserve(numBlocks); 2371 for (uint64_t i = 0; i < numBlocks; ++i) { 2372 readState.curBlocks.push_back(new Block()); 2373 readState.curRegion->push_back(readState.curBlocks.back()); 2374 } 2375 2376 // Prepare the current value scope for this region. 2377 valueScopes.back().push(readState); 2378 2379 // Parse the entry block of the region. 2380 readState.curBlock = readState.curRegion->begin(); 2381 return parseBlockHeader(reader, readState); 2382 } 2383 2384 LogicalResult 2385 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, 2386 RegionReadState &readState) { 2387 bool hasArgs; 2388 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) 2389 return failure(); 2390 2391 // Parse the arguments of the block. 2392 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) 2393 return failure(); 2394 2395 // Uselist orders are available since version 3 of the bytecode. 2396 if (version < bytecode::kUseListOrdering) 2397 return success(); 2398 2399 uint8_t hasUseListOrders = 0; 2400 if (hasArgs && failed(reader.parseByte(hasUseListOrders))) 2401 return failure(); 2402 2403 if (!hasUseListOrders) 2404 return success(); 2405 2406 Block &blk = *readState.curBlock; 2407 auto argIdxToUseListMap = 2408 parseUseListOrderForRange(reader, blk.getNumArguments()); 2409 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) 2410 return failure(); 2411 2412 for (size_t idx = 0; idx < blk.getNumArguments(); idx++) 2413 if (argIdxToUseListMap->contains(idx)) 2414 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), 2415 argIdxToUseListMap->at(idx)); 2416 2417 // We don't parse the operations of the block here, that's done elsewhere. 2418 return success(); 2419 } 2420 2421 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, 2422 Block *block) { 2423 // Parse the value ID for the first argument, and the number of arguments. 2424 uint64_t numArgs; 2425 if (failed(reader.parseVarInt(numArgs))) 2426 return failure(); 2427 2428 SmallVector<Type> argTypes; 2429 SmallVector<Location> argLocs; 2430 argTypes.reserve(numArgs); 2431 argLocs.reserve(numArgs); 2432 2433 Location unknownLoc = UnknownLoc::get(config.getContext()); 2434 while (numArgs--) { 2435 Type argType; 2436 LocationAttr argLoc = unknownLoc; 2437 if (version >= bytecode::kElideUnknownBlockArgLocation) { 2438 // Parse the type with hasLoc flag to determine if it has type. 2439 uint64_t typeIdx; 2440 bool hasLoc; 2441 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || 2442 !(argType = attrTypeReader.resolveType(typeIdx))) 2443 return failure(); 2444 if (hasLoc && failed(parseAttribute(reader, argLoc))) 2445 return failure(); 2446 } else { 2447 // All args has type and location. 2448 if (failed(parseType(reader, argType)) || 2449 failed(parseAttribute(reader, argLoc))) 2450 return failure(); 2451 } 2452 argTypes.push_back(argType); 2453 argLocs.push_back(argLoc); 2454 } 2455 block->addArguments(argTypes, argLocs); 2456 return defineValues(reader, block->getArguments()); 2457 } 2458 2459 //===----------------------------------------------------------------------===// 2460 // Value Processing 2461 2462 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { 2463 std::vector<Value> &values = valueScopes.back().values; 2464 Value *value = nullptr; 2465 if (failed(parseEntry(reader, values, value, "value"))) 2466 return Value(); 2467 2468 // Create a new forward reference if necessary. 2469 if (!*value) 2470 *value = createForwardRef(); 2471 return *value; 2472 } 2473 2474 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, 2475 ValueRange newValues) { 2476 ValueScope &valueScope = valueScopes.back(); 2477 std::vector<Value> &values = valueScope.values; 2478 2479 unsigned &valueID = valueScope.nextValueIDs.back(); 2480 unsigned valueIDEnd = valueID + newValues.size(); 2481 if (valueIDEnd > values.size()) { 2482 return reader.emitError( 2483 "value index range was outside of the expected range for " 2484 "the parent region, got [", 2485 valueID, ", ", valueIDEnd, "), but the maximum index was ", 2486 values.size() - 1); 2487 } 2488 2489 // Assign the values and update any forward references. 2490 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { 2491 Value newValue = newValues[i]; 2492 2493 // Check to see if a definition for this value already exists. 2494 if (Value oldValue = std::exchange(values[valueID], newValue)) { 2495 Operation *forwardRefOp = oldValue.getDefiningOp(); 2496 2497 // Assert that this is a forward reference operation. Given how we compute 2498 // definition ids (incrementally as we parse), it shouldn't be possible 2499 // for the value to be defined any other way. 2500 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && 2501 "value index was already defined?"); 2502 2503 oldValue.replaceAllUsesWith(newValue); 2504 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); 2505 } 2506 } 2507 return success(); 2508 } 2509 2510 Value BytecodeReader::Impl::createForwardRef() { 2511 // Check for an avaliable existing operation to use. Otherwise, create a new 2512 // fake operation to use for the reference. 2513 if (!openForwardRefOps.empty()) { 2514 Operation *op = &openForwardRefOps.back(); 2515 op->moveBefore(&forwardRefOps, forwardRefOps.end()); 2516 } else { 2517 forwardRefOps.push_back(Operation::create(forwardRefOpState)); 2518 } 2519 return forwardRefOps.back().getResult(0); 2520 } 2521 2522 //===----------------------------------------------------------------------===// 2523 // Entry Points 2524 //===----------------------------------------------------------------------===// 2525 2526 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } 2527 2528 BytecodeReader::BytecodeReader( 2529 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, 2530 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2531 Location sourceFileLoc = 2532 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2533 /*line=*/0, /*column=*/0); 2534 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, 2535 bufferOwnerRef); 2536 } 2537 2538 LogicalResult BytecodeReader::readTopLevel( 2539 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2540 return impl->read(block, lazyOpsCallback); 2541 } 2542 2543 int64_t BytecodeReader::getNumOpsToMaterialize() const { 2544 return impl->getNumOpsToMaterialize(); 2545 } 2546 2547 bool BytecodeReader::isMaterializable(Operation *op) { 2548 return impl->isMaterializable(op); 2549 } 2550 2551 LogicalResult BytecodeReader::materialize( 2552 Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2553 return impl->materialize(op, lazyOpsCallback); 2554 } 2555 2556 LogicalResult 2557 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { 2558 return impl->finalize(shouldMaterialize); 2559 } 2560 2561 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { 2562 return buffer.getBuffer().startswith("ML\xefR"); 2563 } 2564 2565 /// Read the bytecode from the provided memory buffer reference. 2566 /// `bufferOwnerRef` if provided is the owning source manager for the buffer, 2567 /// and may be used to extend the lifetime of the buffer. 2568 static LogicalResult 2569 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, 2570 const ParserConfig &config, 2571 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2572 Location sourceFileLoc = 2573 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2574 /*line=*/0, /*column=*/0); 2575 if (!isBytecode(buffer)) { 2576 return emitError(sourceFileLoc, 2577 "input buffer is not an MLIR bytecode file"); 2578 } 2579 2580 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, 2581 buffer, bufferOwnerRef); 2582 return reader.read(block, /*lazyOpsCallback=*/nullptr); 2583 } 2584 2585 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, 2586 const ParserConfig &config) { 2587 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); 2588 } 2589 LogicalResult 2590 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 2591 Block *block, const ParserConfig &config) { 2592 return readBytecodeFileImpl( 2593 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, 2594 sourceMgr); 2595 } 2596