1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 // TODO: Support for big-endian architectures. 10 11 #include "mlir/Bytecode/BytecodeReader.h" 12 #include "mlir/AsmParser/AsmParser.h" 13 #include "mlir/Bytecode/BytecodeImplementation.h" 14 #include "mlir/Bytecode/BytecodeOpInterface.h" 15 #include "mlir/Bytecode/Encoding.h" 16 #include "mlir/IR/BuiltinDialect.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/Diagnostics.h" 19 #include "mlir/IR/OpImplementation.h" 20 #include "mlir/IR/Verifier.h" 21 #include "mlir/IR/Visitors.h" 22 #include "mlir/Support/LLVM.h" 23 #include "mlir/Support/LogicalResult.h" 24 #include "llvm/ADT/ArrayRef.h" 25 #include "llvm/ADT/MapVector.h" 26 #include "llvm/ADT/ScopeExit.h" 27 #include "llvm/ADT/SmallString.h" 28 #include "llvm/ADT/StringExtras.h" 29 #include "llvm/ADT/StringRef.h" 30 #include "llvm/Support/MemoryBufferRef.h" 31 #include "llvm/Support/SaveAndRestore.h" 32 #include "llvm/Support/SourceMgr.h" 33 #include <cstddef> 34 #include <list> 35 #include <memory> 36 #include <numeric> 37 #include <optional> 38 39 #define DEBUG_TYPE "mlir-bytecode-reader" 40 41 using namespace mlir; 42 43 /// Stringify the given section ID. 44 static std::string toString(bytecode::Section::ID sectionID) { 45 switch (sectionID) { 46 case bytecode::Section::kString: 47 return "String (0)"; 48 case bytecode::Section::kDialect: 49 return "Dialect (1)"; 50 case bytecode::Section::kAttrType: 51 return "AttrType (2)"; 52 case bytecode::Section::kAttrTypeOffset: 53 return "AttrTypeOffset (3)"; 54 case bytecode::Section::kIR: 55 return "IR (4)"; 56 case bytecode::Section::kResource: 57 return "Resource (5)"; 58 case bytecode::Section::kResourceOffset: 59 return "ResourceOffset (6)"; 60 case bytecode::Section::kDialectVersions: 61 return "DialectVersions (7)"; 62 case bytecode::Section::kProperties: 63 return "Properties (8)"; 64 default: 65 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); 66 } 67 } 68 69 /// Returns true if the given top-level section ID is optional. 70 static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { 71 switch (sectionID) { 72 case bytecode::Section::kString: 73 case bytecode::Section::kDialect: 74 case bytecode::Section::kAttrType: 75 case bytecode::Section::kAttrTypeOffset: 76 case bytecode::Section::kIR: 77 return false; 78 case bytecode::Section::kResource: 79 case bytecode::Section::kResourceOffset: 80 case bytecode::Section::kDialectVersions: 81 return true; 82 case bytecode::Section::kProperties: 83 return version < 5; 84 default: 85 llvm_unreachable("unknown section ID"); 86 } 87 } 88 89 //===----------------------------------------------------------------------===// 90 // EncodingReader 91 //===----------------------------------------------------------------------===// 92 93 namespace { 94 class EncodingReader { 95 public: 96 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) 97 : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} 98 explicit EncodingReader(StringRef contents, Location fileLoc) 99 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), 100 contents.size()}, 101 fileLoc) {} 102 103 /// Returns true if the entire section has been read. 104 bool empty() const { return dataIt == dataEnd; } 105 106 /// Returns the remaining size of the bytecode. 107 size_t size() const { return dataEnd - dataIt; } 108 109 /// Align the current reader position to the specified alignment. 110 LogicalResult alignTo(unsigned alignment) { 111 if (!llvm::isPowerOf2_32(alignment)) 112 return emitError("expected alignment to be a power-of-two"); 113 114 // Shift the reader position to the next alignment boundary. 115 while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) { 116 uint8_t padding; 117 if (failed(parseByte(padding))) 118 return failure(); 119 if (padding != bytecode::kAlignmentByte) { 120 return emitError("expected alignment byte (0xCB), but got: '0x" + 121 llvm::utohexstr(padding) + "'"); 122 } 123 } 124 125 // Ensure the data iterator is now aligned. This case is unlikely because we 126 // *just* went through the effort to align the data iterator. 127 if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) { 128 return emitError("expected data iterator aligned to ", alignment, 129 ", but got pointer: '0x" + 130 llvm::utohexstr((uintptr_t)dataIt) + "'"); 131 } 132 133 return success(); 134 } 135 136 /// Emit an error using the given arguments. 137 template <typename... Args> 138 InFlightDiagnostic emitError(Args &&...args) const { 139 return ::emitError(fileLoc).append(std::forward<Args>(args)...); 140 } 141 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } 142 143 /// Parse a single byte from the stream. 144 template <typename T> 145 LogicalResult parseByte(T &value) { 146 if (empty()) 147 return emitError("attempting to parse a byte at the end of the bytecode"); 148 value = static_cast<T>(*dataIt++); 149 return success(); 150 } 151 /// Parse a range of bytes of 'length' into the given result. 152 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { 153 if (length > size()) { 154 return emitError("attempting to parse ", length, " bytes when only ", 155 size(), " remain"); 156 } 157 result = {dataIt, length}; 158 dataIt += length; 159 return success(); 160 } 161 /// Parse a range of bytes of 'length' into the given result, which can be 162 /// assumed to be large enough to hold `length`. 163 LogicalResult parseBytes(size_t length, uint8_t *result) { 164 if (length > size()) { 165 return emitError("attempting to parse ", length, " bytes when only ", 166 size(), " remain"); 167 } 168 memcpy(result, dataIt, length); 169 dataIt += length; 170 return success(); 171 } 172 173 /// Parse an aligned blob of data, where the alignment was encoded alongside 174 /// the data. 175 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, 176 uint64_t &alignment) { 177 uint64_t dataSize; 178 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || 179 failed(alignTo(alignment))) 180 return failure(); 181 return parseBytes(dataSize, data); 182 } 183 184 /// Parse a variable length encoded integer from the byte stream. The first 185 /// encoded byte contains a prefix in the low bits indicating the encoded 186 /// length of the value. This length prefix is a bit sequence of '0's followed 187 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes 188 /// (not including the prefix byte). All remaining bits in the first byte, 189 /// along with all of the bits in additional bytes, provide the value of the 190 /// integer encoded in little-endian order. 191 LogicalResult parseVarInt(uint64_t &result) { 192 // Parse the first byte of the encoding, which contains the length prefix. 193 if (failed(parseByte(result))) 194 return failure(); 195 196 // Handle the overwhelmingly common case where the value is stored in a 197 // single byte. In this case, the first bit is the `1` marker bit. 198 if (LLVM_LIKELY(result & 1)) { 199 result >>= 1; 200 return success(); 201 } 202 203 // Handle the overwhelming uncommon case where the value required all 8 204 // bytes (i.e. a really really big number). In this case, the marker byte is 205 // all zeros: `00000000`. 206 if (LLVM_UNLIKELY(result == 0)) 207 return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result)); 208 return parseMultiByteVarInt(result); 209 } 210 211 /// Parse a signed variable length encoded integer from the byte stream. A 212 /// signed varint is encoded as a normal varint with zigzag encoding applied, 213 /// i.e. the low bit of the value is used to indicate the sign. 214 LogicalResult parseSignedVarInt(uint64_t &result) { 215 if (failed(parseVarInt(result))) 216 return failure(); 217 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) 218 result = (result >> 1) ^ (~(result & 1) + 1); 219 return success(); 220 } 221 222 /// Parse a variable length encoded integer whose low bit is used to encode an 223 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. 224 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { 225 if (failed(parseVarInt(result))) 226 return failure(); 227 flag = result & 1; 228 result >>= 1; 229 return success(); 230 } 231 232 /// Skip the first `length` bytes within the reader. 233 LogicalResult skipBytes(size_t length) { 234 if (length > size()) { 235 return emitError("attempting to skip ", length, " bytes when only ", 236 size(), " remain"); 237 } 238 dataIt += length; 239 return success(); 240 } 241 242 /// Parse a null-terminated string into `result` (without including the NUL 243 /// terminator). 244 LogicalResult parseNullTerminatedString(StringRef &result) { 245 const char *startIt = (const char *)dataIt; 246 const char *nulIt = (const char *)memchr(startIt, 0, size()); 247 if (!nulIt) 248 return emitError( 249 "malformed null-terminated string, no null character found"); 250 251 result = StringRef(startIt, nulIt - startIt); 252 dataIt = (const uint8_t *)nulIt + 1; 253 return success(); 254 } 255 256 /// Parse a section header, placing the kind of section in `sectionID` and the 257 /// contents of the section in `sectionData`. 258 LogicalResult parseSection(bytecode::Section::ID §ionID, 259 ArrayRef<uint8_t> §ionData) { 260 uint8_t sectionIDAndHasAlignment; 261 uint64_t length; 262 if (failed(parseByte(sectionIDAndHasAlignment)) || 263 failed(parseVarInt(length))) 264 return failure(); 265 266 // Extract the section ID and whether the section is aligned. The high bit 267 // of the ID is the alignment flag. 268 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & 269 0b01111111); 270 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; 271 272 // Check that the section is actually valid before trying to process its 273 // data. 274 if (sectionID >= bytecode::Section::kNumSections) 275 return emitError("invalid section ID: ", unsigned(sectionID)); 276 277 // Process the section alignment if present. 278 if (hasAlignment) { 279 uint64_t alignment; 280 if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) 281 return failure(); 282 } 283 284 // Parse the actual section data. 285 return parseBytes(static_cast<size_t>(length), sectionData); 286 } 287 288 Location getLoc() const { return fileLoc; } 289 290 private: 291 /// Parse a variable length encoded integer from the byte stream. This method 292 /// is a fallback when the number of bytes used to encode the value is greater 293 /// than 1, but less than the max (9). The provided `result` value can be 294 /// assumed to already contain the first byte of the value. 295 /// NOTE: This method is marked noinline to avoid pessimizing the common case 296 /// of single byte encoding. 297 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) { 298 // Count the number of trailing zeros in the marker byte, this indicates the 299 // number of trailing bytes that are part of the value. We use `uint32_t` 300 // here because we only care about the first byte, and so that be actually 301 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop 302 // implementation). 303 uint32_t numBytes = llvm::countr_zero<uint32_t>(result); 304 assert(numBytes > 0 && numBytes <= 7 && 305 "unexpected number of trailing zeros in varint encoding"); 306 307 // Parse in the remaining bytes of the value. 308 if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1))) 309 return failure(); 310 311 // Shift out the low-order bits that were used to mark how the value was 312 // encoded. 313 result >>= (numBytes + 1); 314 return success(); 315 } 316 317 /// The current data iterator, and an iterator to the end of the buffer. 318 const uint8_t *dataIt, *dataEnd; 319 320 /// A location for the bytecode used to report errors. 321 Location fileLoc; 322 }; 323 } // namespace 324 325 /// Resolve an index into the given entry list. `entry` may either be a 326 /// reference, in which case it is assigned to the corresponding value in 327 /// `entries`, or a pointer, in which case it is assigned to the address of the 328 /// element in `entries`. 329 template <typename RangeT, typename T> 330 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, 331 uint64_t index, T &entry, 332 StringRef entryStr) { 333 if (index >= entries.size()) 334 return reader.emitError("invalid ", entryStr, " index: ", index); 335 336 // If the provided entry is a pointer, resolve to the address of the entry. 337 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) 338 entry = entries[index]; 339 else 340 entry = &entries[index]; 341 return success(); 342 } 343 344 /// Parse and resolve an index into the given entry list. 345 template <typename RangeT, typename T> 346 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, 347 T &entry, StringRef entryStr) { 348 uint64_t entryIdx; 349 if (failed(reader.parseVarInt(entryIdx))) 350 return failure(); 351 return resolveEntry(reader, entries, entryIdx, entry, entryStr); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // StringSectionReader 356 //===----------------------------------------------------------------------===// 357 358 namespace { 359 /// This class is used to read references to the string section from the 360 /// bytecode. 361 class StringSectionReader { 362 public: 363 /// Initialize the string section reader with the given section data. 364 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); 365 366 /// Parse a shared string from the string section. The shared string is 367 /// encoded using an index to a corresponding string in the string section. 368 LogicalResult parseString(EncodingReader &reader, StringRef &result) { 369 return parseEntry(reader, strings, result, "string"); 370 } 371 372 /// Parse a shared string from the string section. The shared string is 373 /// encoded using an index to a corresponding string in the string section. 374 /// This variant parses a flag compressed with the index. 375 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, 376 bool &flag) { 377 uint64_t entryIdx; 378 if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) 379 return failure(); 380 return parseStringAtIndex(reader, entryIdx, result); 381 } 382 383 /// Parse a shared string from the string section. The shared string is 384 /// encoded using an index to a corresponding string in the string section. 385 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, 386 StringRef &result) { 387 return resolveEntry(reader, strings, index, result, "string"); 388 } 389 390 private: 391 /// The table of strings referenced within the bytecode file. 392 SmallVector<StringRef> strings; 393 }; 394 } // namespace 395 396 LogicalResult StringSectionReader::initialize(Location fileLoc, 397 ArrayRef<uint8_t> sectionData) { 398 EncodingReader stringReader(sectionData, fileLoc); 399 400 // Parse the number of strings in the section. 401 uint64_t numStrings; 402 if (failed(stringReader.parseVarInt(numStrings))) 403 return failure(); 404 strings.resize(numStrings); 405 406 // Parse each of the strings. The sizes of the strings are encoded in reverse 407 // order, so that's the order we populate the table. 408 size_t stringDataEndOffset = sectionData.size(); 409 for (StringRef &string : llvm::reverse(strings)) { 410 uint64_t stringSize; 411 if (failed(stringReader.parseVarInt(stringSize))) 412 return failure(); 413 if (stringDataEndOffset < stringSize) { 414 return stringReader.emitError( 415 "string size exceeds the available data size"); 416 } 417 418 // Extract the string from the data, dropping the null character. 419 size_t stringOffset = stringDataEndOffset - stringSize; 420 string = StringRef( 421 reinterpret_cast<const char *>(sectionData.data() + stringOffset), 422 stringSize - 1); 423 stringDataEndOffset = stringOffset; 424 } 425 426 // Check that the only remaining data was for the strings, i.e. the reader 427 // should be at the same offset as the first string. 428 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { 429 return stringReader.emitError("unexpected trailing data between the " 430 "offsets for strings and their data"); 431 } 432 return success(); 433 } 434 435 //===----------------------------------------------------------------------===// 436 // BytecodeDialect 437 //===----------------------------------------------------------------------===// 438 439 namespace { 440 class DialectReader; 441 442 /// This struct represents a dialect entry within the bytecode. 443 struct BytecodeDialect { 444 /// Load the dialect into the provided context if it hasn't been loaded yet. 445 /// Returns failure if the dialect couldn't be loaded *and* the provided 446 /// context does not allow unregistered dialects. The provided reader is used 447 /// for error emission if necessary. 448 LogicalResult load(DialectReader &reader, MLIRContext *ctx); 449 450 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can 451 /// only be called after `load`. 452 Dialect *getLoadedDialect() const { 453 assert(dialect && 454 "expected `load` to be invoked before `getLoadedDialect`"); 455 return *dialect; 456 } 457 458 /// The loaded dialect entry. This field is std::nullopt if we haven't 459 /// attempted to load, nullptr if we failed to load, otherwise the loaded 460 /// dialect. 461 std::optional<Dialect *> dialect; 462 463 /// The bytecode interface of the dialect, or nullptr if the dialect does not 464 /// implement the bytecode interface. This field should only be checked if the 465 /// `dialect` field is not std::nullopt. 466 const BytecodeDialectInterface *interface = nullptr; 467 468 /// The name of the dialect. 469 StringRef name; 470 471 /// A buffer containing the encoding of the dialect version parsed. 472 ArrayRef<uint8_t> versionBuffer; 473 474 /// Lazy loaded dialect version from the handle above. 475 std::unique_ptr<DialectVersion> loadedVersion; 476 }; 477 478 /// This struct represents an operation name entry within the bytecode. 479 struct BytecodeOperationName { 480 BytecodeOperationName(BytecodeDialect *dialect, StringRef name, 481 std::optional<bool> wasRegistered) 482 : dialect(dialect), name(name), wasRegistered(wasRegistered) {} 483 484 /// The loaded operation name, or std::nullopt if it hasn't been processed 485 /// yet. 486 std::optional<OperationName> opName; 487 488 /// The dialect that owns this operation name. 489 BytecodeDialect *dialect; 490 491 /// The name of the operation, without the dialect prefix. 492 StringRef name; 493 494 /// Whether this operation was registered when the bytecode was produced. 495 /// This flag is populated when bytecode version >=5. 496 std::optional<bool> wasRegistered; 497 }; 498 } // namespace 499 500 /// Parse a single dialect group encoded in the byte stream. 501 static LogicalResult parseDialectGrouping( 502 EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, 503 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { 504 // Parse the dialect and the number of entries in the group. 505 BytecodeDialect *dialect; 506 if (failed(parseEntry(reader, dialects, dialect, "dialect"))) 507 return failure(); 508 uint64_t numEntries; 509 if (failed(reader.parseVarInt(numEntries))) 510 return failure(); 511 512 for (uint64_t i = 0; i < numEntries; ++i) 513 if (failed(entryCallback(dialect))) 514 return failure(); 515 return success(); 516 } 517 518 //===----------------------------------------------------------------------===// 519 // ResourceSectionReader 520 //===----------------------------------------------------------------------===// 521 522 namespace { 523 /// This class is used to read the resource section from the bytecode. 524 class ResourceSectionReader { 525 public: 526 /// Initialize the resource section reader with the given section data. 527 LogicalResult 528 initialize(Location fileLoc, const ParserConfig &config, 529 MutableArrayRef<BytecodeDialect> dialects, 530 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 531 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 532 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); 533 534 /// Parse a dialect resource handle from the resource section. 535 LogicalResult parseResourceHandle(EncodingReader &reader, 536 AsmDialectResourceHandle &result) { 537 return parseEntry(reader, dialectResources, result, "resource handle"); 538 } 539 540 private: 541 /// The table of dialect resources within the bytecode file. 542 SmallVector<AsmDialectResourceHandle> dialectResources; 543 llvm::StringMap<std::string> dialectResourceHandleRenamingMap; 544 }; 545 546 class ParsedResourceEntry : public AsmParsedResourceEntry { 547 public: 548 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, 549 EncodingReader &reader, StringSectionReader &stringReader, 550 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 551 : key(key), kind(kind), reader(reader), stringReader(stringReader), 552 bufferOwnerRef(bufferOwnerRef) {} 553 ~ParsedResourceEntry() override = default; 554 555 StringRef getKey() const final { return key; } 556 557 InFlightDiagnostic emitError() const final { return reader.emitError(); } 558 559 AsmResourceEntryKind getKind() const final { return kind; } 560 561 FailureOr<bool> parseAsBool() const final { 562 if (kind != AsmResourceEntryKind::Bool) 563 return emitError() << "expected a bool resource entry, but found a " 564 << toString(kind) << " entry instead"; 565 566 bool value; 567 if (failed(reader.parseByte(value))) 568 return failure(); 569 return value; 570 } 571 FailureOr<std::string> parseAsString() const final { 572 if (kind != AsmResourceEntryKind::String) 573 return emitError() << "expected a string resource entry, but found a " 574 << toString(kind) << " entry instead"; 575 576 StringRef string; 577 if (failed(stringReader.parseString(reader, string))) 578 return failure(); 579 return string.str(); 580 } 581 582 FailureOr<AsmResourceBlob> 583 parseAsBlob(BlobAllocatorFn allocator) const final { 584 if (kind != AsmResourceEntryKind::Blob) 585 return emitError() << "expected a blob resource entry, but found a " 586 << toString(kind) << " entry instead"; 587 588 ArrayRef<uint8_t> data; 589 uint64_t alignment; 590 if (failed(reader.parseBlobAndAlignment(data, alignment))) 591 return failure(); 592 593 // If we have an extendable reference to the buffer owner, we don't need to 594 // allocate a new buffer for the data, and can use the data directly. 595 if (bufferOwnerRef) { 596 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), 597 data.size()); 598 599 // Allocate an unmanager buffer which captures a reference to the owner. 600 // For now we just mark this as immutable, but in the future we should 601 // explore marking this as mutable when desired. 602 return UnmanagedAsmResourceBlob::allocateWithAlign( 603 charData, alignment, 604 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); 605 } 606 607 // Allocate memory for the blob using the provided allocator and copy the 608 // data into it. 609 AsmResourceBlob blob = allocator(data.size(), alignment); 610 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && 611 blob.isMutable() && 612 "blob allocator did not return a properly aligned address"); 613 memcpy(blob.getMutableData().data(), data.data(), data.size()); 614 return blob; 615 } 616 617 private: 618 StringRef key; 619 AsmResourceEntryKind kind; 620 EncodingReader &reader; 621 StringSectionReader &stringReader; 622 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 623 }; 624 } // namespace 625 626 template <typename T> 627 static LogicalResult 628 parseResourceGroup(Location fileLoc, bool allowEmpty, 629 EncodingReader &offsetReader, EncodingReader &resourceReader, 630 StringSectionReader &stringReader, T *handler, 631 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, 632 function_ref<StringRef(StringRef)> remapKey = {}, 633 function_ref<LogicalResult(StringRef)> processKeyFn = {}) { 634 uint64_t numResources; 635 if (failed(offsetReader.parseVarInt(numResources))) 636 return failure(); 637 638 for (uint64_t i = 0; i < numResources; ++i) { 639 StringRef key; 640 AsmResourceEntryKind kind; 641 uint64_t resourceOffset; 642 ArrayRef<uint8_t> data; 643 if (failed(stringReader.parseString(offsetReader, key)) || 644 failed(offsetReader.parseVarInt(resourceOffset)) || 645 failed(offsetReader.parseByte(kind)) || 646 failed(resourceReader.parseBytes(resourceOffset, data))) 647 return failure(); 648 649 // Process the resource key. 650 if ((processKeyFn && failed(processKeyFn(key)))) 651 return failure(); 652 653 // If the resource data is empty and we allow it, don't error out when 654 // parsing below, just skip it. 655 if (allowEmpty && data.empty()) 656 continue; 657 658 // Ignore the entry if we don't have a valid handler. 659 if (!handler) 660 continue; 661 662 // Otherwise, parse the resource value. 663 EncodingReader entryReader(data, fileLoc); 664 key = remapKey(key); 665 ParsedResourceEntry entry(key, kind, entryReader, stringReader, 666 bufferOwnerRef); 667 if (failed(handler->parseResource(entry))) 668 return failure(); 669 if (!entryReader.empty()) { 670 return entryReader.emitError( 671 "unexpected trailing bytes in resource entry '", key, "'"); 672 } 673 } 674 return success(); 675 } 676 677 LogicalResult ResourceSectionReader::initialize( 678 Location fileLoc, const ParserConfig &config, 679 MutableArrayRef<BytecodeDialect> dialects, 680 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 681 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 682 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 683 EncodingReader resourceReader(sectionData, fileLoc); 684 EncodingReader offsetReader(offsetSectionData, fileLoc); 685 686 // Read the number of external resource providers. 687 uint64_t numExternalResourceGroups; 688 if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) 689 return failure(); 690 691 // Utility functor that dispatches to `parseResourceGroup`, but implicitly 692 // provides most of the arguments. 693 auto parseGroup = [&](auto *handler, bool allowEmpty = false, 694 function_ref<LogicalResult(StringRef)> keyFn = {}) { 695 auto resolveKey = [&](StringRef key) -> StringRef { 696 auto it = dialectResourceHandleRenamingMap.find(key); 697 if (it == dialectResourceHandleRenamingMap.end()) 698 return ""; 699 return it->second; 700 }; 701 702 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, 703 stringReader, handler, bufferOwnerRef, resolveKey, 704 keyFn); 705 }; 706 707 // Read the external resources from the bytecode. 708 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { 709 StringRef key; 710 if (failed(stringReader.parseString(offsetReader, key))) 711 return failure(); 712 713 // Get the handler for these resources. 714 // TODO: Should we require handling external resources in some scenarios? 715 AsmResourceParser *handler = config.getResourceParser(key); 716 if (!handler) { 717 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key 718 << "'"; 719 } 720 721 if (failed(parseGroup(handler))) 722 return failure(); 723 } 724 725 // Read the dialect resources from the bytecode. 726 MLIRContext *ctx = fileLoc->getContext(); 727 while (!offsetReader.empty()) { 728 BytecodeDialect *dialect; 729 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || 730 failed(dialect->load(dialectReader, ctx))) 731 return failure(); 732 Dialect *loadedDialect = dialect->getLoadedDialect(); 733 if (!loadedDialect) { 734 return resourceReader.emitError() 735 << "dialect '" << dialect->name << "' is unknown"; 736 } 737 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); 738 if (!handler) { 739 return resourceReader.emitError() 740 << "unexpected resources for dialect '" << dialect->name << "'"; 741 } 742 743 // Ensure that each resource is declared before being processed. 744 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { 745 FailureOr<AsmDialectResourceHandle> handle = 746 handler->declareResource(key); 747 if (failed(handle)) { 748 return resourceReader.emitError() 749 << "unknown 'resource' key '" << key << "' for dialect '" 750 << dialect->name << "'"; 751 } 752 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); 753 dialectResources.push_back(*handle); 754 return success(); 755 }; 756 757 // Parse the resources for this dialect. We allow empty resources because we 758 // just treat these as declarations. 759 if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) 760 return failure(); 761 } 762 763 return success(); 764 } 765 766 //===----------------------------------------------------------------------===// 767 // Attribute/Type Reader 768 //===----------------------------------------------------------------------===// 769 770 namespace { 771 /// This class provides support for reading attribute and type entries from the 772 /// bytecode. Attribute and Type entries are read lazily on demand, so we use 773 /// this reader to manage when to actually parse them from the bytecode. 774 class AttrTypeReader { 775 /// This class represents a single attribute or type entry. 776 template <typename T> 777 struct Entry { 778 /// The entry, or null if it hasn't been resolved yet. 779 T entry = {}; 780 /// The parent dialect of this entry. 781 BytecodeDialect *dialect = nullptr; 782 /// A flag indicating if the entry was encoded using a custom encoding, 783 /// instead of using the textual assembly format. 784 bool hasCustomEncoding = false; 785 /// The raw data of this entry in the bytecode. 786 ArrayRef<uint8_t> data; 787 }; 788 using AttrEntry = Entry<Attribute>; 789 using TypeEntry = Entry<Type>; 790 791 public: 792 AttrTypeReader(StringSectionReader &stringReader, 793 ResourceSectionReader &resourceReader, Location fileLoc) 794 : stringReader(stringReader), resourceReader(resourceReader), 795 fileLoc(fileLoc) {} 796 797 /// Initialize the attribute and type information within the reader. 798 LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, 799 ArrayRef<uint8_t> sectionData, 800 ArrayRef<uint8_t> offsetSectionData); 801 802 /// Resolve the attribute or type at the given index. Returns nullptr on 803 /// failure. 804 Attribute resolveAttribute(size_t index) { 805 return resolveEntry(attributes, index, "Attribute"); 806 } 807 Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } 808 809 /// Parse a reference to an attribute or type using the given reader. 810 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { 811 uint64_t attrIdx; 812 if (failed(reader.parseVarInt(attrIdx))) 813 return failure(); 814 result = resolveAttribute(attrIdx); 815 return success(!!result); 816 } 817 LogicalResult parseOptionalAttribute(EncodingReader &reader, 818 Attribute &result) { 819 uint64_t attrIdx; 820 bool flag; 821 if (failed(reader.parseVarIntWithFlag(attrIdx, flag))) 822 return failure(); 823 if (!flag) 824 return success(); 825 result = resolveAttribute(attrIdx); 826 return success(!!result); 827 } 828 829 LogicalResult parseType(EncodingReader &reader, Type &result) { 830 uint64_t typeIdx; 831 if (failed(reader.parseVarInt(typeIdx))) 832 return failure(); 833 result = resolveType(typeIdx); 834 return success(!!result); 835 } 836 837 template <typename T> 838 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 839 Attribute baseResult; 840 if (failed(parseAttribute(reader, baseResult))) 841 return failure(); 842 if ((result = dyn_cast<T>(baseResult))) 843 return success(); 844 return reader.emitError("expected attribute of type: ", 845 llvm::getTypeName<T>(), ", but got: ", baseResult); 846 } 847 848 private: 849 /// Resolve the given entry at `index`. 850 template <typename T> 851 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 852 StringRef entryType); 853 854 /// Parse an entry using the given reader that was encoded using the textual 855 /// assembly format. 856 template <typename T> 857 LogicalResult parseAsmEntry(T &result, EncodingReader &reader, 858 StringRef entryType); 859 860 /// Parse an entry using the given reader that was encoded using a custom 861 /// bytecode format. 862 template <typename T> 863 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, 864 StringRef entryType); 865 866 /// The string section reader used to resolve string references when parsing 867 /// custom encoded attribute/type entries. 868 StringSectionReader &stringReader; 869 870 /// The resource section reader used to resolve resource references when 871 /// parsing custom encoded attribute/type entries. 872 ResourceSectionReader &resourceReader; 873 874 /// The set of attribute and type entries. 875 SmallVector<AttrEntry> attributes; 876 SmallVector<TypeEntry> types; 877 878 /// A location used for error emission. 879 Location fileLoc; 880 }; 881 882 class DialectReader : public DialectBytecodeReader { 883 public: 884 DialectReader(AttrTypeReader &attrTypeReader, 885 StringSectionReader &stringReader, 886 ResourceSectionReader &resourceReader, EncodingReader &reader) 887 : attrTypeReader(attrTypeReader), stringReader(stringReader), 888 resourceReader(resourceReader), reader(reader) {} 889 890 InFlightDiagnostic emitError(const Twine &msg) override { 891 return reader.emitError(msg); 892 } 893 894 DialectReader withEncodingReader(EncodingReader &encReader) { 895 return DialectReader(attrTypeReader, stringReader, resourceReader, 896 encReader); 897 } 898 899 Location getLoc() const { return reader.getLoc(); } 900 901 //===--------------------------------------------------------------------===// 902 // IR 903 //===--------------------------------------------------------------------===// 904 905 LogicalResult readAttribute(Attribute &result) override { 906 return attrTypeReader.parseAttribute(reader, result); 907 } 908 LogicalResult readOptionalAttribute(Attribute &result) override { 909 return attrTypeReader.parseOptionalAttribute(reader, result); 910 } 911 LogicalResult readType(Type &result) override { 912 return attrTypeReader.parseType(reader, result); 913 } 914 915 FailureOr<AsmDialectResourceHandle> readResourceHandle() override { 916 AsmDialectResourceHandle handle; 917 if (failed(resourceReader.parseResourceHandle(reader, handle))) 918 return failure(); 919 return handle; 920 } 921 922 //===--------------------------------------------------------------------===// 923 // Primitives 924 //===--------------------------------------------------------------------===// 925 926 LogicalResult readVarInt(uint64_t &result) override { 927 return reader.parseVarInt(result); 928 } 929 930 LogicalResult readSignedVarInt(int64_t &result) override { 931 uint64_t unsignedResult; 932 if (failed(reader.parseSignedVarInt(unsignedResult))) 933 return failure(); 934 result = static_cast<int64_t>(unsignedResult); 935 return success(); 936 } 937 938 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { 939 // Small values are encoded using a single byte. 940 if (bitWidth <= 8) { 941 uint8_t value; 942 if (failed(reader.parseByte(value))) 943 return failure(); 944 return APInt(bitWidth, value); 945 } 946 947 // Large values up to 64 bits are encoded using a single varint. 948 if (bitWidth <= 64) { 949 uint64_t value; 950 if (failed(reader.parseSignedVarInt(value))) 951 return failure(); 952 return APInt(bitWidth, value); 953 } 954 955 // Otherwise, for really big values we encode the array of active words in 956 // the value. 957 uint64_t numActiveWords; 958 if (failed(reader.parseVarInt(numActiveWords))) 959 return failure(); 960 SmallVector<uint64_t, 4> words(numActiveWords); 961 for (uint64_t i = 0; i < numActiveWords; ++i) 962 if (failed(reader.parseSignedVarInt(words[i]))) 963 return failure(); 964 return APInt(bitWidth, words); 965 } 966 967 FailureOr<APFloat> 968 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { 969 FailureOr<APInt> intVal = 970 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); 971 if (failed(intVal)) 972 return failure(); 973 return APFloat(semantics, *intVal); 974 } 975 976 LogicalResult readString(StringRef &result) override { 977 return stringReader.parseString(reader, result); 978 } 979 980 LogicalResult readBlob(ArrayRef<char> &result) override { 981 uint64_t dataSize; 982 ArrayRef<uint8_t> data; 983 if (failed(reader.parseVarInt(dataSize)) || 984 failed(reader.parseBytes(dataSize, data))) 985 return failure(); 986 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), 987 data.size()); 988 return success(); 989 } 990 991 private: 992 AttrTypeReader &attrTypeReader; 993 StringSectionReader &stringReader; 994 ResourceSectionReader &resourceReader; 995 EncodingReader &reader; 996 }; 997 998 /// Wraps the properties section and handles reading properties out of it. 999 class PropertiesSectionReader { 1000 public: 1001 /// Initialize the properties section reader with the given section data. 1002 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { 1003 if (sectionData.empty()) 1004 return success(); 1005 EncodingReader propReader(sectionData, fileLoc); 1006 uint64_t count; 1007 if (failed(propReader.parseVarInt(count))) 1008 return failure(); 1009 // Parse the raw properties buffer. 1010 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) 1011 return failure(); 1012 1013 EncodingReader offsetsReader(propertiesBuffers, fileLoc); 1014 offsetTable.reserve(count); 1015 for (auto idx : llvm::seq<int64_t>(0, count)) { 1016 (void)idx; 1017 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); 1018 ArrayRef<uint8_t> rawProperties; 1019 uint64_t dataSize; 1020 if (failed(offsetsReader.parseVarInt(dataSize)) || 1021 failed(offsetsReader.parseBytes(dataSize, rawProperties))) 1022 return failure(); 1023 } 1024 if (!offsetsReader.empty()) 1025 return offsetsReader.emitError() 1026 << "Broken properties section: didn't exhaust the offsets table"; 1027 return success(); 1028 } 1029 1030 LogicalResult read(Location fileLoc, DialectReader &dialectReader, 1031 OperationName *opName, OperationState &opState) { 1032 uint64_t propertiesIdx; 1033 if (failed(dialectReader.readVarInt(propertiesIdx))) 1034 return failure(); 1035 if (propertiesIdx >= offsetTable.size()) 1036 return dialectReader.emitError("Properties idx out-of-bound for ") 1037 << opName->getStringRef(); 1038 size_t propertiesOffset = offsetTable[propertiesIdx]; 1039 if (propertiesIdx >= propertiesBuffers.size()) 1040 return dialectReader.emitError("Properties offset out-of-bound for ") 1041 << opName->getStringRef(); 1042 1043 // Acquire the sub-buffer that represent the requested properties. 1044 ArrayRef<char> rawProperties; 1045 { 1046 // "Seek" to the requested offset by getting a new reader with the right 1047 // sub-buffer. 1048 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), 1049 fileLoc); 1050 // Properties are stored as a sequence of {size + raw_data}. 1051 if (failed( 1052 dialectReader.withEncodingReader(reader).readBlob(rawProperties))) 1053 return failure(); 1054 } 1055 // Setup a new reader to read from the `rawProperties` sub-buffer. 1056 EncodingReader reader( 1057 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); 1058 DialectReader propReader = dialectReader.withEncodingReader(reader); 1059 1060 auto *iface = opName->getInterface<BytecodeOpInterface>(); 1061 if (iface) 1062 return iface->readProperties(propReader, opState); 1063 if (opName->isRegistered()) 1064 return propReader.emitError( 1065 "has properties but missing BytecodeOpInterface for ") 1066 << opName->getStringRef(); 1067 // Unregistered op are storing properties as an attribute. 1068 return propReader.readAttribute(opState.propertiesAttr); 1069 } 1070 1071 private: 1072 /// The properties buffer referenced within the bytecode file. 1073 ArrayRef<uint8_t> propertiesBuffers; 1074 1075 /// Table of offset in the buffer above. 1076 SmallVector<int64_t> offsetTable; 1077 }; 1078 } // namespace 1079 1080 LogicalResult 1081 AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, 1082 ArrayRef<uint8_t> sectionData, 1083 ArrayRef<uint8_t> offsetSectionData) { 1084 EncodingReader offsetReader(offsetSectionData, fileLoc); 1085 1086 // Parse the number of attribute and type entries. 1087 uint64_t numAttributes, numTypes; 1088 if (failed(offsetReader.parseVarInt(numAttributes)) || 1089 failed(offsetReader.parseVarInt(numTypes))) 1090 return failure(); 1091 attributes.resize(numAttributes); 1092 types.resize(numTypes); 1093 1094 // A functor used to accumulate the offsets for the entries in the given 1095 // range. 1096 uint64_t currentOffset = 0; 1097 auto parseEntries = [&](auto &&range) { 1098 size_t currentIndex = 0, endIndex = range.size(); 1099 1100 // Parse an individual entry. 1101 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { 1102 auto &entry = range[currentIndex++]; 1103 1104 uint64_t entrySize; 1105 if (failed(offsetReader.parseVarIntWithFlag(entrySize, 1106 entry.hasCustomEncoding))) 1107 return failure(); 1108 1109 // Verify that the offset is actually valid. 1110 if (currentOffset + entrySize > sectionData.size()) { 1111 return offsetReader.emitError( 1112 "Attribute or Type entry offset points past the end of section"); 1113 } 1114 1115 entry.data = sectionData.slice(currentOffset, entrySize); 1116 entry.dialect = dialect; 1117 currentOffset += entrySize; 1118 return success(); 1119 }; 1120 while (currentIndex != endIndex) 1121 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) 1122 return failure(); 1123 return success(); 1124 }; 1125 1126 // Process each of the attributes, and then the types. 1127 if (failed(parseEntries(attributes)) || failed(parseEntries(types))) 1128 return failure(); 1129 1130 // Ensure that we read everything from the section. 1131 if (!offsetReader.empty()) { 1132 return offsetReader.emitError( 1133 "unexpected trailing data in the Attribute/Type offset section"); 1134 } 1135 return success(); 1136 } 1137 1138 template <typename T> 1139 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 1140 StringRef entryType) { 1141 if (index >= entries.size()) { 1142 emitError(fileLoc) << "invalid " << entryType << " index: " << index; 1143 return {}; 1144 } 1145 1146 // If the entry has already been resolved, there is nothing left to do. 1147 Entry<T> &entry = entries[index]; 1148 if (entry.entry) 1149 return entry.entry; 1150 1151 // Parse the entry. 1152 EncodingReader reader(entry.data, fileLoc); 1153 1154 // Parse based on how the entry was encoded. 1155 if (entry.hasCustomEncoding) { 1156 if (failed(parseCustomEntry(entry, reader, entryType))) 1157 return T(); 1158 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { 1159 return T(); 1160 } 1161 1162 if (!reader.empty()) { 1163 reader.emitError("unexpected trailing bytes after " + entryType + " entry"); 1164 return T(); 1165 } 1166 return entry.entry; 1167 } 1168 1169 template <typename T> 1170 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, 1171 StringRef entryType) { 1172 StringRef asmStr; 1173 if (failed(reader.parseNullTerminatedString(asmStr))) 1174 return failure(); 1175 1176 // Invoke the MLIR assembly parser to parse the entry text. 1177 size_t numRead = 0; 1178 MLIRContext *context = fileLoc->getContext(); 1179 if constexpr (std::is_same_v<T, Type>) 1180 result = 1181 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); 1182 else 1183 result = ::parseAttribute(asmStr, context, Type(), &numRead, 1184 /*isKnownNullTerminated=*/true); 1185 if (!result) 1186 return failure(); 1187 1188 // Ensure there weren't dangling characters after the entry. 1189 if (numRead != asmStr.size()) { 1190 return reader.emitError("trailing characters found after ", entryType, 1191 " assembly format: ", asmStr.drop_front(numRead)); 1192 } 1193 return success(); 1194 } 1195 1196 template <typename T> 1197 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, 1198 EncodingReader &reader, 1199 StringRef entryType) { 1200 DialectReader dialectReader(*this, stringReader, resourceReader, reader); 1201 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) 1202 return failure(); 1203 // Ensure that the dialect implements the bytecode interface. 1204 if (!entry.dialect->interface) { 1205 return reader.emitError("dialect '", entry.dialect->name, 1206 "' does not implement the bytecode interface"); 1207 } 1208 1209 // Ask the dialect to parse the entry. If the dialect is versioned, parse 1210 // using the versioned encoding readers. 1211 if (entry.dialect->loadedVersion.get()) { 1212 if constexpr (std::is_same_v<T, Type>) 1213 entry.entry = entry.dialect->interface->readType( 1214 dialectReader, *entry.dialect->loadedVersion); 1215 else 1216 entry.entry = entry.dialect->interface->readAttribute( 1217 dialectReader, *entry.dialect->loadedVersion); 1218 1219 } else { 1220 if constexpr (std::is_same_v<T, Type>) 1221 entry.entry = entry.dialect->interface->readType(dialectReader); 1222 else 1223 entry.entry = entry.dialect->interface->readAttribute(dialectReader); 1224 } 1225 return success(!!entry.entry); 1226 } 1227 1228 //===----------------------------------------------------------------------===// 1229 // Bytecode Reader 1230 //===----------------------------------------------------------------------===// 1231 1232 /// This class is used to read a bytecode buffer and translate it into MLIR. 1233 class mlir::BytecodeReader::Impl { 1234 struct RegionReadState; 1235 using LazyLoadableOpsInfo = 1236 std::list<std::pair<Operation *, RegionReadState>>; 1237 using LazyLoadableOpsMap = 1238 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; 1239 1240 public: 1241 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, 1242 llvm::MemoryBufferRef buffer, 1243 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 1244 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), 1245 attrTypeReader(stringReader, resourceReader, fileLoc), 1246 // Use the builtin unrealized conversion cast operation to represent 1247 // forward references to values that aren't yet defined. 1248 forwardRefOpState(UnknownLoc::get(config.getContext()), 1249 "builtin.unrealized_conversion_cast", ValueRange(), 1250 NoneType::get(config.getContext())), 1251 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} 1252 1253 /// Read the bytecode defined within `buffer` into the given block. 1254 LogicalResult read(Block *block, 1255 llvm::function_ref<bool(Operation *)> lazyOps); 1256 1257 /// Return the number of ops that haven't been materialized yet. 1258 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } 1259 1260 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } 1261 1262 /// Materialize the provided operation, invoke the lazyOpsCallback on every 1263 /// newly found lazy operation. 1264 LogicalResult 1265 materialize(Operation *op, 1266 llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1267 this->lazyOpsCallback = lazyOpsCallback; 1268 auto resetlazyOpsCallback = 1269 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1270 auto it = lazyLoadableOpsMap.find(op); 1271 assert(it != lazyLoadableOpsMap.end() && 1272 "materialize called on non-materializable op"); 1273 return materialize(it); 1274 } 1275 1276 /// Materialize all operations. 1277 LogicalResult materializeAll() { 1278 while (!lazyLoadableOpsMap.empty()) { 1279 if (failed(materialize(lazyLoadableOpsMap.begin()))) 1280 return failure(); 1281 } 1282 return success(); 1283 } 1284 1285 /// Finalize the lazy-loading by calling back with every op that hasn't been 1286 /// materialized to let the client decide if the op should be deleted or 1287 /// materialized. The op is materialized if the callback returns true, deleted 1288 /// otherwise. 1289 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { 1290 while (!lazyLoadableOps.empty()) { 1291 Operation *op = lazyLoadableOps.begin()->first; 1292 if (shouldMaterialize(op)) { 1293 if (failed(materialize(lazyLoadableOpsMap.find(op)))) 1294 return failure(); 1295 continue; 1296 } 1297 op->dropAllReferences(); 1298 op->erase(); 1299 lazyLoadableOps.pop_front(); 1300 lazyLoadableOpsMap.erase(op); 1301 } 1302 return success(); 1303 } 1304 1305 private: 1306 LogicalResult materialize(LazyLoadableOpsMap::iterator it) { 1307 assert(it != lazyLoadableOpsMap.end() && 1308 "materialize called on non-materializable op"); 1309 valueScopes.emplace_back(); 1310 std::vector<RegionReadState> regionStack; 1311 regionStack.push_back(std::move(it->getSecond()->second)); 1312 lazyLoadableOps.erase(it->getSecond()); 1313 lazyLoadableOpsMap.erase(it); 1314 auto result = parseRegions(regionStack, regionStack.back()); 1315 assert((regionStack.empty() || failed(result)) && 1316 "broken invariant: regionStack should be empty when parseRegions " 1317 "succeeds"); 1318 return result; 1319 } 1320 1321 /// Return the context for this config. 1322 MLIRContext *getContext() const { return config.getContext(); } 1323 1324 /// Parse the bytecode version. 1325 LogicalResult parseVersion(EncodingReader &reader); 1326 1327 //===--------------------------------------------------------------------===// 1328 // Dialect Section 1329 1330 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); 1331 1332 /// Parse an operation name reference using the given reader, and set the 1333 /// `wasRegistered` flag that indicates if the bytecode was produced by a 1334 /// context where opName was registered. 1335 FailureOr<OperationName> parseOpName(EncodingReader &reader, 1336 std::optional<bool> &wasRegistered); 1337 1338 //===--------------------------------------------------------------------===// 1339 // Attribute/Type Section 1340 1341 /// Parse an attribute or type using the given reader. 1342 template <typename T> 1343 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 1344 return attrTypeReader.parseAttribute(reader, result); 1345 } 1346 LogicalResult parseType(EncodingReader &reader, Type &result) { 1347 return attrTypeReader.parseType(reader, result); 1348 } 1349 1350 //===--------------------------------------------------------------------===// 1351 // Resource Section 1352 1353 LogicalResult 1354 parseResourceSection(EncodingReader &reader, 1355 std::optional<ArrayRef<uint8_t>> resourceData, 1356 std::optional<ArrayRef<uint8_t>> resourceOffsetData); 1357 1358 //===--------------------------------------------------------------------===// 1359 // IR Section 1360 1361 /// This struct represents the current read state of a range of regions. This 1362 /// struct is used to enable iterative parsing of regions. 1363 struct RegionReadState { 1364 RegionReadState(Operation *op, EncodingReader *reader, 1365 bool isIsolatedFromAbove) 1366 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} 1367 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, 1368 bool isIsolatedFromAbove) 1369 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), 1370 isIsolatedFromAbove(isIsolatedFromAbove) {} 1371 1372 /// The current regions being read. 1373 MutableArrayRef<Region>::iterator curRegion, endRegion; 1374 /// This is the reader to use for this region, this pointer is pointing to 1375 /// the parent region reader unless the current region is IsolatedFromAbove, 1376 /// in which case the pointer is pointing to the `owningReader` which is a 1377 /// section dedicated to the current region. 1378 EncodingReader *reader; 1379 std::unique_ptr<EncodingReader> owningReader; 1380 1381 /// The number of values defined immediately within this region. 1382 unsigned numValues = 0; 1383 1384 /// The current blocks of the region being read. 1385 SmallVector<Block *> curBlocks; 1386 Region::iterator curBlock = {}; 1387 1388 /// The number of operations remaining to be read from the current block 1389 /// being read. 1390 uint64_t numOpsRemaining = 0; 1391 1392 /// A flag indicating if the regions being read are isolated from above. 1393 bool isIsolatedFromAbove = false; 1394 }; 1395 1396 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); 1397 LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, 1398 RegionReadState &readState); 1399 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, 1400 RegionReadState &readState, 1401 bool &isIsolatedFromAbove); 1402 1403 LogicalResult parseRegion(RegionReadState &readState); 1404 LogicalResult parseBlockHeader(EncodingReader &reader, 1405 RegionReadState &readState); 1406 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); 1407 1408 //===--------------------------------------------------------------------===// 1409 // Value Processing 1410 1411 /// Parse an operand reference using the given reader. Returns nullptr in the 1412 /// case of failure. 1413 Value parseOperand(EncodingReader &reader); 1414 1415 /// Sequentially define the given value range. 1416 LogicalResult defineValues(EncodingReader &reader, ValueRange values); 1417 1418 /// Create a value to use for a forward reference. 1419 Value createForwardRef(); 1420 1421 //===--------------------------------------------------------------------===// 1422 // Use-list order helpers 1423 1424 /// This struct is a simple storage that contains information required to 1425 /// reorder the use-list of a value with respect to the pre-order traversal 1426 /// ordering. 1427 struct UseListOrderStorage { 1428 UseListOrderStorage(bool isIndexPairEncoding, 1429 SmallVector<unsigned, 4> &&indices) 1430 : indices(std::move(indices)), 1431 isIndexPairEncoding(isIndexPairEncoding){}; 1432 /// The vector containing the information required to reorder the 1433 /// use-list of a value. 1434 SmallVector<unsigned, 4> indices; 1435 1436 /// Whether indices represent a pair of type `(src, dst)` or it is a direct 1437 /// indexing, such as `dst = order[src]`. 1438 bool isIndexPairEncoding; 1439 }; 1440 1441 /// Parse use-list order from bytecode for a range of values if available. The 1442 /// range is expected to be either a block argument or an op result range. On 1443 /// success, return a map of the position in the range and the use-list order 1444 /// encoding. The function assumes to know the size of the range it is 1445 /// processing. 1446 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; 1447 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, 1448 uint64_t rangeSize); 1449 1450 /// Shuffle the use-chain according to the order parsed. 1451 LogicalResult sortUseListOrder(Value value); 1452 1453 /// Recursively visit all the values defined within topLevelOp and sort the 1454 /// use-list orders according to the indices parsed. 1455 LogicalResult processUseLists(Operation *topLevelOp); 1456 1457 //===--------------------------------------------------------------------===// 1458 // Fields 1459 1460 /// This class represents a single value scope, in which a value scope is 1461 /// delimited by isolated from above regions. 1462 struct ValueScope { 1463 /// Push a new region state onto this scope, reserving enough values for 1464 /// those defined within the current region of the provided state. 1465 void push(RegionReadState &readState) { 1466 nextValueIDs.push_back(values.size()); 1467 values.resize(values.size() + readState.numValues); 1468 } 1469 1470 /// Pop the values defined for the current region within the provided region 1471 /// state. 1472 void pop(RegionReadState &readState) { 1473 values.resize(values.size() - readState.numValues); 1474 nextValueIDs.pop_back(); 1475 } 1476 1477 /// The set of values defined in this scope. 1478 std::vector<Value> values; 1479 1480 /// The ID for the next defined value for each region current being 1481 /// processed in this scope. 1482 SmallVector<unsigned, 4> nextValueIDs; 1483 }; 1484 1485 /// The configuration of the parser. 1486 const ParserConfig &config; 1487 1488 /// A location to use when emitting errors. 1489 Location fileLoc; 1490 1491 /// Flag that indicates if lazyloading is enabled. 1492 bool lazyLoading; 1493 1494 /// Keep track of operations that have been lazy loaded (their regions haven't 1495 /// been materialized), along with the `RegionReadState` that allows to 1496 /// lazy-load the regions nested under the operation. 1497 LazyLoadableOpsInfo lazyLoadableOps; 1498 LazyLoadableOpsMap lazyLoadableOpsMap; 1499 llvm::function_ref<bool(Operation *)> lazyOpsCallback; 1500 1501 /// The reader used to process attribute and types within the bytecode. 1502 AttrTypeReader attrTypeReader; 1503 1504 /// The version of the bytecode being read. 1505 uint64_t version = 0; 1506 1507 /// The producer of the bytecode being read. 1508 StringRef producer; 1509 1510 /// The table of IR units referenced within the bytecode file. 1511 SmallVector<BytecodeDialect> dialects; 1512 SmallVector<BytecodeOperationName> opNames; 1513 1514 /// The reader used to process resources within the bytecode. 1515 ResourceSectionReader resourceReader; 1516 1517 /// Worklist of values with custom use-list orders to process before the end 1518 /// of the parsing. 1519 DenseMap<void *, UseListOrderStorage> valueToUseListMap; 1520 1521 /// The table of strings referenced within the bytecode file. 1522 StringSectionReader stringReader; 1523 1524 /// The table of properties referenced by the operation in the bytecode file. 1525 PropertiesSectionReader propertiesReader; 1526 1527 /// The current set of available IR value scopes. 1528 std::vector<ValueScope> valueScopes; 1529 1530 /// The global pre-order operation ordering. 1531 DenseMap<Operation *, unsigned> operationIDs; 1532 1533 /// A block containing the set of operations defined to create forward 1534 /// references. 1535 Block forwardRefOps; 1536 1537 /// A block containing previously created, and no longer used, forward 1538 /// reference operations. 1539 Block openForwardRefOps; 1540 1541 /// An operation state used when instantiating forward references. 1542 OperationState forwardRefOpState; 1543 1544 /// Reference to the input buffer. 1545 llvm::MemoryBufferRef buffer; 1546 1547 /// The optional owning source manager, which when present may be used to 1548 /// extend the lifetime of the input buffer. 1549 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 1550 }; 1551 1552 LogicalResult BytecodeReader::Impl::read( 1553 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1554 EncodingReader reader(buffer.getBuffer(), fileLoc); 1555 this->lazyOpsCallback = lazyOpsCallback; 1556 auto resetlazyOpsCallback = 1557 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1558 1559 // Skip over the bytecode header, this should have already been checked. 1560 if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) 1561 return failure(); 1562 // Parse the bytecode version and producer. 1563 if (failed(parseVersion(reader)) || 1564 failed(reader.parseNullTerminatedString(producer))) 1565 return failure(); 1566 1567 // Add a diagnostic handler that attaches a note that includes the original 1568 // producer of the bytecode. 1569 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { 1570 diag.attachNote() << "in bytecode version " << version 1571 << " produced by: " << producer; 1572 return failure(); 1573 }); 1574 1575 // Parse the raw data for each of the top-level sections of the bytecode. 1576 std::optional<ArrayRef<uint8_t>> 1577 sectionDatas[bytecode::Section::kNumSections]; 1578 while (!reader.empty()) { 1579 // Read the next section from the bytecode. 1580 bytecode::Section::ID sectionID; 1581 ArrayRef<uint8_t> sectionData; 1582 if (failed(reader.parseSection(sectionID, sectionData))) 1583 return failure(); 1584 1585 // Check for duplicate sections, we only expect one instance of each. 1586 if (sectionDatas[sectionID]) { 1587 return reader.emitError("duplicate top-level section: ", 1588 ::toString(sectionID)); 1589 } 1590 sectionDatas[sectionID] = sectionData; 1591 } 1592 // Check that all of the required sections were found. 1593 for (int i = 0; i < bytecode::Section::kNumSections; ++i) { 1594 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); 1595 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { 1596 return reader.emitError("missing data for top-level section: ", 1597 ::toString(sectionID)); 1598 } 1599 } 1600 1601 // Process the string section first. 1602 if (failed(stringReader.initialize( 1603 fileLoc, *sectionDatas[bytecode::Section::kString]))) 1604 return failure(); 1605 1606 // Process the properties section. 1607 if (sectionDatas[bytecode::Section::kProperties] && 1608 failed(propertiesReader.initialize( 1609 fileLoc, *sectionDatas[bytecode::Section::kProperties]))) 1610 return failure(); 1611 1612 // Process the dialect section. 1613 if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) 1614 return failure(); 1615 1616 // Process the resource section if present. 1617 if (failed(parseResourceSection( 1618 reader, sectionDatas[bytecode::Section::kResource], 1619 sectionDatas[bytecode::Section::kResourceOffset]))) 1620 return failure(); 1621 1622 // Process the attribute and type section. 1623 if (failed(attrTypeReader.initialize( 1624 dialects, *sectionDatas[bytecode::Section::kAttrType], 1625 *sectionDatas[bytecode::Section::kAttrTypeOffset]))) 1626 return failure(); 1627 1628 // Finally, process the IR section. 1629 return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); 1630 } 1631 1632 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { 1633 if (failed(reader.parseVarInt(version))) 1634 return failure(); 1635 1636 // Validate the bytecode version. 1637 uint64_t currentVersion = bytecode::kVersion; 1638 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; 1639 if (version < minSupportedVersion) { 1640 return reader.emitError("bytecode version ", version, 1641 " is older than the current version of ", 1642 currentVersion, ", and upgrade is not supported"); 1643 } 1644 if (version > currentVersion) { 1645 return reader.emitError("bytecode version ", version, 1646 " is newer than the current version ", 1647 currentVersion); 1648 } 1649 // Override any request to lazy-load if the bytecode version is too old. 1650 if (version < 2) 1651 lazyLoading = false; 1652 return success(); 1653 } 1654 1655 //===----------------------------------------------------------------------===// 1656 // Dialect Section 1657 1658 LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { 1659 if (dialect) 1660 return success(); 1661 Dialect *loadedDialect = ctx->getOrLoadDialect(name); 1662 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { 1663 return reader.emitError("dialect '") 1664 << name 1665 << "' is unknown. If this is intended, please call " 1666 "allowUnregisteredDialects() on the MLIRContext, or use " 1667 "-allow-unregistered-dialect with the MLIR tool used."; 1668 } 1669 dialect = loadedDialect; 1670 1671 // If the dialect was actually loaded, check to see if it has a bytecode 1672 // interface. 1673 if (loadedDialect) 1674 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); 1675 if (!versionBuffer.empty()) { 1676 if (!interface) 1677 return reader.emitError("dialect '") 1678 << name 1679 << "' does not implement the bytecode interface, " 1680 "but found a version entry"; 1681 EncodingReader encReader(versionBuffer, reader.getLoc()); 1682 DialectReader versionReader = reader.withEncodingReader(encReader); 1683 loadedVersion = interface->readVersion(versionReader); 1684 if (!loadedVersion) 1685 return failure(); 1686 } 1687 return success(); 1688 } 1689 1690 LogicalResult 1691 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { 1692 EncodingReader sectionReader(sectionData, fileLoc); 1693 1694 // Parse the number of dialects in the section. 1695 uint64_t numDialects; 1696 if (failed(sectionReader.parseVarInt(numDialects))) 1697 return failure(); 1698 dialects.resize(numDialects); 1699 1700 // Parse each of the dialects. 1701 for (uint64_t i = 0; i < numDialects; ++i) { 1702 /// Before version 1, there wasn't any versioning available for dialects, 1703 /// and the entryIdx represent the string itself. 1704 if (version == 0) { 1705 if (failed(stringReader.parseString(sectionReader, dialects[i].name))) 1706 return failure(); 1707 continue; 1708 } 1709 // Parse ID representing dialect and version. 1710 uint64_t dialectNameIdx; 1711 bool versionAvailable; 1712 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, 1713 versionAvailable))) 1714 return failure(); 1715 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, 1716 dialects[i].name))) 1717 return failure(); 1718 if (versionAvailable) { 1719 bytecode::Section::ID sectionID; 1720 if (failed( 1721 sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) 1722 return failure(); 1723 if (sectionID != bytecode::Section::kDialectVersions) { 1724 emitError(fileLoc, "expected dialect version section"); 1725 return failure(); 1726 } 1727 } 1728 } 1729 1730 // Parse the operation names, which are grouped by dialect. 1731 auto parseOpName = [&](BytecodeDialect *dialect) { 1732 StringRef opName; 1733 std::optional<bool> wasRegistered; 1734 // Prior to version 5, the information about wheter an op was registered or 1735 // not wasn't encoded. 1736 if (version < 5) { 1737 if (failed(stringReader.parseString(sectionReader, opName))) 1738 return failure(); 1739 } else { 1740 bool wasRegisteredFlag; 1741 if (failed(stringReader.parseStringWithFlag(sectionReader, opName, 1742 wasRegisteredFlag))) 1743 return failure(); 1744 wasRegistered = wasRegisteredFlag; 1745 } 1746 opNames.emplace_back(dialect, opName, wasRegistered); 1747 return success(); 1748 }; 1749 // Avoid re-allocation in bytecode version > 3 where the number of ops are 1750 // known. 1751 if (version > 3) { 1752 uint64_t numOps; 1753 if (failed(sectionReader.parseVarInt(numOps))) 1754 return failure(); 1755 opNames.reserve(numOps); 1756 } 1757 while (!sectionReader.empty()) 1758 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) 1759 return failure(); 1760 return success(); 1761 } 1762 1763 FailureOr<OperationName> 1764 BytecodeReader::Impl::parseOpName(EncodingReader &reader, 1765 std::optional<bool> &wasRegistered) { 1766 BytecodeOperationName *opName = nullptr; 1767 if (failed(parseEntry(reader, opNames, opName, "operation name"))) 1768 return failure(); 1769 wasRegistered = opName->wasRegistered; 1770 // Check to see if this operation name has already been resolved. If we 1771 // haven't, load the dialect and build the operation name. 1772 if (!opName->opName) { 1773 // Load the dialect and its version. 1774 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1775 reader); 1776 if (failed(opName->dialect->load(dialectReader, getContext()))) 1777 return failure(); 1778 // If the opName is empty, this is because we use to accept names such as 1779 // `foo` without any `.` separator. We shouldn't tolerate this in textual 1780 // format anymore but for now we'll be backward compatible. This can only 1781 // happen with unregistered dialects. 1782 if (opName->name.empty()) { 1783 if (opName->dialect->getLoadedDialect()) 1784 return emitError(fileLoc) << "has an empty opname for dialect '" 1785 << opName->dialect->name << "'\n"; 1786 1787 opName->opName.emplace(opName->dialect->name, getContext()); 1788 } else { 1789 opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), 1790 getContext()); 1791 } 1792 } 1793 return *opName->opName; 1794 } 1795 1796 //===----------------------------------------------------------------------===// 1797 // Resource Section 1798 1799 LogicalResult BytecodeReader::Impl::parseResourceSection( 1800 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, 1801 std::optional<ArrayRef<uint8_t>> resourceOffsetData) { 1802 // Ensure both sections are either present or not. 1803 if (resourceData.has_value() != resourceOffsetData.has_value()) { 1804 if (resourceOffsetData) 1805 return emitError(fileLoc, "unexpected resource offset section when " 1806 "resource section is not present"); 1807 return emitError( 1808 fileLoc, 1809 "expected resource offset section when resource section is present"); 1810 } 1811 1812 // If the resource sections are absent, there is nothing to do. 1813 if (!resourceData) 1814 return success(); 1815 1816 // Initialize the resource reader with the resource sections. 1817 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1818 reader); 1819 return resourceReader.initialize(fileLoc, config, dialects, stringReader, 1820 *resourceData, *resourceOffsetData, 1821 dialectReader, bufferOwnerRef); 1822 } 1823 1824 //===----------------------------------------------------------------------===// 1825 // UseListOrder Helpers 1826 1827 FailureOr<BytecodeReader::Impl::UseListMapT> 1828 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, 1829 uint64_t numResults) { 1830 BytecodeReader::Impl::UseListMapT map; 1831 uint64_t numValuesToRead = 1; 1832 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) 1833 return failure(); 1834 1835 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { 1836 uint64_t resultIdx = 0; 1837 if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) 1838 return failure(); 1839 1840 uint64_t numValues; 1841 bool indexPairEncoding; 1842 if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) 1843 return failure(); 1844 1845 SmallVector<unsigned, 4> useListOrders; 1846 for (size_t idx = 0; idx < numValues; idx++) { 1847 uint64_t index; 1848 if (failed(reader.parseVarInt(index))) 1849 return failure(); 1850 useListOrders.push_back(index); 1851 } 1852 1853 // Store in a map the result index 1854 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, 1855 std::move(useListOrders))); 1856 } 1857 1858 return map; 1859 } 1860 1861 /// Sorts each use according to the order specified in the use-list parsed. If 1862 /// the custom use-list is not found, this means that the order needs to be 1863 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie 1864 /// on the same operation, the order will follow the reverse operand number 1865 /// ordering. 1866 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { 1867 // Early return for trivial use-lists. 1868 if (value.use_empty() || value.hasOneUse()) 1869 return success(); 1870 1871 bool hasIncomingOrder = 1872 valueToUseListMap.contains(value.getAsOpaquePointer()); 1873 1874 // Compute the current order of the use-list with respect to the global 1875 // ordering. Detect if the order is already sorted while doing so. 1876 bool alreadySorted = true; 1877 auto &firstUse = *value.use_begin(); 1878 uint64_t prevID = 1879 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); 1880 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; 1881 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1882 uint64_t currentID = bytecode::getUseID( 1883 item.value(), operationIDs.at(item.value().getOwner())); 1884 alreadySorted &= prevID > currentID; 1885 currentOrder.push_back({item.index(), currentID}); 1886 prevID = currentID; 1887 } 1888 1889 // If the order is already sorted, and there wasn't a custom order to apply 1890 // from the bytecode file, we are done. 1891 if (alreadySorted && !hasIncomingOrder) 1892 return success(); 1893 1894 // If not already sorted, sort the indices of the current order by descending 1895 // useIDs. 1896 if (!alreadySorted) 1897 std::sort( 1898 currentOrder.begin(), currentOrder.end(), 1899 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1900 1901 if (!hasIncomingOrder) { 1902 // If the bytecode file did not contain any custom use-list order, it means 1903 // that the order was descending useID. Hence, shuffle by the first index 1904 // of the `currentOrder` pair. 1905 SmallVector<unsigned> shuffle = SmallVector<unsigned>( 1906 llvm::map_range(currentOrder, [&](auto item) { return item.first; })); 1907 value.shuffleUseList(shuffle); 1908 return success(); 1909 } 1910 1911 // Pull the custom order info from the map. 1912 UseListOrderStorage customOrder = 1913 valueToUseListMap.at(value.getAsOpaquePointer()); 1914 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); 1915 uint64_t numUses = 1916 std::distance(value.getUses().begin(), value.getUses().end()); 1917 1918 // If the encoding was a pair of indices `(src, dst)` for every permutation, 1919 // reconstruct the shuffle vector for every use. Initialize the shuffle vector 1920 // as identity, and then apply the mapping encoded in the indices. 1921 if (customOrder.isIndexPairEncoding) { 1922 // Return failure if the number of indices was not representing pairs. 1923 if (shuffle.size() & 1) 1924 return failure(); 1925 1926 SmallVector<unsigned, 4> newShuffle(numUses); 1927 size_t idx = 0; 1928 std::iota(newShuffle.begin(), newShuffle.end(), idx); 1929 for (idx = 0; idx < shuffle.size(); idx += 2) 1930 newShuffle[shuffle[idx]] = shuffle[idx + 1]; 1931 1932 shuffle = std::move(newShuffle); 1933 } 1934 1935 // Make sure that the indices represent a valid mapping. That is, the sum of 1936 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no 1937 // duplicates are allowed in the list. 1938 DenseSet<unsigned> set; 1939 uint64_t accumulator = 0; 1940 for (const auto &elem : shuffle) { 1941 if (set.contains(elem)) 1942 return failure(); 1943 accumulator += elem; 1944 set.insert(elem); 1945 } 1946 if (numUses != shuffle.size() || 1947 accumulator != (((numUses - 1) * numUses) >> 1)) 1948 return failure(); 1949 1950 // Apply the current ordering map onto the shuffle vector to get the final 1951 // use-list sorting indices before shuffling. 1952 shuffle = SmallVector<unsigned, 4>(llvm::map_range( 1953 currentOrder, [&](auto item) { return shuffle[item.first]; })); 1954 value.shuffleUseList(shuffle); 1955 return success(); 1956 } 1957 1958 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { 1959 // Precompute operation IDs according to the pre-order walk of the IR. We 1960 // can't do this while parsing since parseRegions ordering is not strictly 1961 // equal to the pre-order walk. 1962 unsigned operationID = 0; 1963 topLevelOp->walk<mlir::WalkOrder::PreOrder>( 1964 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); 1965 1966 auto blockWalk = topLevelOp->walk([this](Block *block) { 1967 for (auto arg : block->getArguments()) 1968 if (failed(sortUseListOrder(arg))) 1969 return WalkResult::interrupt(); 1970 return WalkResult::advance(); 1971 }); 1972 1973 auto resultWalk = topLevelOp->walk([this](Operation *op) { 1974 for (auto result : op->getResults()) 1975 if (failed(sortUseListOrder(result))) 1976 return WalkResult::interrupt(); 1977 return WalkResult::advance(); 1978 }); 1979 1980 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); 1981 } 1982 1983 //===----------------------------------------------------------------------===// 1984 // IR Section 1985 1986 LogicalResult 1987 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, 1988 Block *block) { 1989 EncodingReader reader(sectionData, fileLoc); 1990 1991 // A stack of operation regions currently being read from the bytecode. 1992 std::vector<RegionReadState> regionStack; 1993 1994 // Parse the top-level block using a temporary module operation. 1995 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); 1996 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); 1997 regionStack.back().curBlocks.push_back(moduleOp->getBody()); 1998 regionStack.back().curBlock = regionStack.back().curRegion->begin(); 1999 if (failed(parseBlockHeader(reader, regionStack.back()))) 2000 return failure(); 2001 valueScopes.emplace_back(); 2002 valueScopes.back().push(regionStack.back()); 2003 2004 // Iteratively parse regions until everything has been resolved. 2005 while (!regionStack.empty()) 2006 if (failed(parseRegions(regionStack, regionStack.back()))) 2007 return failure(); 2008 if (!forwardRefOps.empty()) { 2009 return reader.emitError( 2010 "not all forward unresolved forward operand references"); 2011 } 2012 2013 // Sort use-lists according to what specified in bytecode. 2014 if (failed(processUseLists(*moduleOp))) 2015 return reader.emitError( 2016 "parsed use-list orders were invalid and could not be applied"); 2017 2018 // Resolve dialect version. 2019 for (const BytecodeDialect &byteCodeDialect : dialects) { 2020 // Parsing is complete, give an opportunity to each dialect to visit the 2021 // IR and perform upgrades. 2022 if (!byteCodeDialect.loadedVersion) 2023 continue; 2024 if (byteCodeDialect.interface && 2025 failed(byteCodeDialect.interface->upgradeFromVersion( 2026 *moduleOp, *byteCodeDialect.loadedVersion))) 2027 return failure(); 2028 } 2029 2030 // Verify that the parsed operations are valid. 2031 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) 2032 return failure(); 2033 2034 // Splice the parsed operations over to the provided top-level block. 2035 auto &parsedOps = moduleOp->getBody()->getOperations(); 2036 auto &destOps = block->getOperations(); 2037 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); 2038 return success(); 2039 } 2040 2041 LogicalResult 2042 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, 2043 RegionReadState &readState) { 2044 // Process regions, blocks, and operations until the end or if a nested 2045 // region is encountered. In this case we push a new state in regionStack and 2046 // return, the processing of the current region will resume afterward. 2047 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { 2048 // If the current block hasn't been setup yet, parse the header for this 2049 // region. The current block is already setup when this function was 2050 // interrupted to recurse down in a nested region and we resume the current 2051 // block after processing the nested region. 2052 if (readState.curBlock == Region::iterator()) { 2053 if (failed(parseRegion(readState))) 2054 return failure(); 2055 2056 // If the region is empty, there is nothing to more to do. 2057 if (readState.curRegion->empty()) 2058 continue; 2059 } 2060 2061 // Parse the blocks within the region. 2062 EncodingReader &reader = *readState.reader; 2063 do { 2064 while (readState.numOpsRemaining--) { 2065 // Read in the next operation. We don't read its regions directly, we 2066 // handle those afterwards as necessary. 2067 bool isIsolatedFromAbove = false; 2068 FailureOr<Operation *> op = 2069 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); 2070 if (failed(op)) 2071 return failure(); 2072 2073 // If the op has regions, add it to the stack for processing and return: 2074 // we stop the processing of the current region and resume it after the 2075 // inner one is completed. Unless LazyLoading is activated in which case 2076 // nested region parsing is delayed. 2077 if ((*op)->getNumRegions()) { 2078 RegionReadState childState(*op, &reader, isIsolatedFromAbove); 2079 2080 // Isolated regions are encoded as a section in version 2 and above. 2081 if (version >= 2 && isIsolatedFromAbove) { 2082 bytecode::Section::ID sectionID; 2083 ArrayRef<uint8_t> sectionData; 2084 if (failed(reader.parseSection(sectionID, sectionData))) 2085 return failure(); 2086 if (sectionID != bytecode::Section::kIR) 2087 return emitError(fileLoc, "expected IR section for region"); 2088 childState.owningReader = 2089 std::make_unique<EncodingReader>(sectionData, fileLoc); 2090 childState.reader = childState.owningReader.get(); 2091 } 2092 2093 if (lazyLoading) { 2094 // If the user has a callback set, they have the opportunity 2095 // to control lazyloading as we go. 2096 if (!lazyOpsCallback || !lazyOpsCallback(*op)) { 2097 lazyLoadableOps.push_back( 2098 std::make_pair(*op, std::move(childState))); 2099 lazyLoadableOpsMap.try_emplace(*op, 2100 std::prev(lazyLoadableOps.end())); 2101 continue; 2102 } 2103 } 2104 regionStack.push_back(std::move(childState)); 2105 2106 // If the op is isolated from above, push a new value scope. 2107 if (isIsolatedFromAbove) 2108 valueScopes.emplace_back(); 2109 return success(); 2110 } 2111 } 2112 2113 // Move to the next block of the region. 2114 if (++readState.curBlock == readState.curRegion->end()) 2115 break; 2116 if (failed(parseBlockHeader(reader, readState))) 2117 return failure(); 2118 } while (true); 2119 2120 // Reset the current block and any values reserved for this region. 2121 readState.curBlock = {}; 2122 valueScopes.back().pop(readState); 2123 } 2124 2125 // When the regions have been fully parsed, pop them off of the read stack. If 2126 // the regions were isolated from above, we also pop the last value scope. 2127 if (readState.isIsolatedFromAbove) { 2128 assert(!valueScopes.empty() && "Expect a valueScope after reading region"); 2129 valueScopes.pop_back(); 2130 } 2131 assert(!regionStack.empty() && "Expect a regionStack after reading region"); 2132 regionStack.pop_back(); 2133 return success(); 2134 } 2135 2136 FailureOr<Operation *> 2137 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, 2138 RegionReadState &readState, 2139 bool &isIsolatedFromAbove) { 2140 // Parse the name of the operation. 2141 std::optional<bool> wasRegistered; 2142 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); 2143 if (failed(opName)) 2144 return failure(); 2145 2146 // Parse the operation mask, which indicates which components of the operation 2147 // are present. 2148 uint8_t opMask; 2149 if (failed(reader.parseByte(opMask))) 2150 return failure(); 2151 2152 /// Parse the location. 2153 LocationAttr opLoc; 2154 if (failed(parseAttribute(reader, opLoc))) 2155 return failure(); 2156 2157 // With the location and name resolved, we can start building the operation 2158 // state. 2159 OperationState opState(opLoc, *opName); 2160 2161 // Parse the attributes of the operation. 2162 if (opMask & bytecode::OpEncodingMask::kHasAttrs) { 2163 DictionaryAttr dictAttr; 2164 if (failed(parseAttribute(reader, dictAttr))) 2165 return failure(); 2166 opState.attributes = dictAttr; 2167 } 2168 2169 if (opMask & bytecode::OpEncodingMask::kHasProperties) { 2170 // kHasProperties wasn't emitted in older bytecode, we should never get 2171 // there without also having the `wasRegistered` flag available. 2172 if (!wasRegistered) 2173 return emitError(fileLoc, 2174 "Unexpected missing `wasRegistered` opname flag at " 2175 "bytecode version ") 2176 << version << " with properties."; 2177 // When an operation is emitted without being registered, the properties are 2178 // stored as an attribute. Otherwise the op must implement the bytecode 2179 // interface and control the serialization. 2180 if (wasRegistered) { 2181 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 2182 reader); 2183 if (failed( 2184 propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) 2185 return failure(); 2186 } else { 2187 // If the operation wasn't registered when it was emitted, the properties 2188 // was serialized as an attribute. 2189 if (failed(parseAttribute(reader, opState.propertiesAttr))) 2190 return failure(); 2191 } 2192 } 2193 2194 /// Parse the results of the operation. 2195 if (opMask & bytecode::OpEncodingMask::kHasResults) { 2196 uint64_t numResults; 2197 if (failed(reader.parseVarInt(numResults))) 2198 return failure(); 2199 opState.types.resize(numResults); 2200 for (int i = 0, e = numResults; i < e; ++i) 2201 if (failed(parseType(reader, opState.types[i]))) 2202 return failure(); 2203 } 2204 2205 /// Parse the operands of the operation. 2206 if (opMask & bytecode::OpEncodingMask::kHasOperands) { 2207 uint64_t numOperands; 2208 if (failed(reader.parseVarInt(numOperands))) 2209 return failure(); 2210 opState.operands.resize(numOperands); 2211 for (int i = 0, e = numOperands; i < e; ++i) 2212 if (!(opState.operands[i] = parseOperand(reader))) 2213 return failure(); 2214 } 2215 2216 /// Parse the successors of the operation. 2217 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { 2218 uint64_t numSuccs; 2219 if (failed(reader.parseVarInt(numSuccs))) 2220 return failure(); 2221 opState.successors.resize(numSuccs); 2222 for (int i = 0, e = numSuccs; i < e; ++i) { 2223 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], 2224 "successor"))) 2225 return failure(); 2226 } 2227 } 2228 2229 /// Parse the use-list orders for the results of the operation. Use-list 2230 /// orders are available since version 3 of the bytecode. 2231 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; 2232 if (version > 2 && (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { 2233 size_t numResults = opState.types.size(); 2234 auto parseResult = parseUseListOrderForRange(reader, numResults); 2235 if (failed(parseResult)) 2236 return failure(); 2237 resultIdxToUseListMap = std::move(*parseResult); 2238 } 2239 2240 /// Parse the regions of the operation. 2241 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { 2242 uint64_t numRegions; 2243 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) 2244 return failure(); 2245 2246 opState.regions.reserve(numRegions); 2247 for (int i = 0, e = numRegions; i < e; ++i) 2248 opState.regions.push_back(std::make_unique<Region>()); 2249 } 2250 2251 // Create the operation at the back of the current block. 2252 Operation *op = Operation::create(opState); 2253 readState.curBlock->push_back(op); 2254 2255 // If the operation had results, update the value references. 2256 if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) 2257 return failure(); 2258 2259 /// Store a map for every value that received a custom use-list order from the 2260 /// bytecode file. 2261 if (resultIdxToUseListMap.has_value()) { 2262 for (size_t idx = 0; idx < op->getNumResults(); idx++) { 2263 if (resultIdxToUseListMap->contains(idx)) { 2264 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), 2265 resultIdxToUseListMap->at(idx)); 2266 } 2267 } 2268 } 2269 return op; 2270 } 2271 2272 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { 2273 EncodingReader &reader = *readState.reader; 2274 2275 // Parse the number of blocks in the region. 2276 uint64_t numBlocks; 2277 if (failed(reader.parseVarInt(numBlocks))) 2278 return failure(); 2279 2280 // If the region is empty, there is nothing else to do. 2281 if (numBlocks == 0) 2282 return success(); 2283 2284 // Parse the number of values defined in this region. 2285 uint64_t numValues; 2286 if (failed(reader.parseVarInt(numValues))) 2287 return failure(); 2288 readState.numValues = numValues; 2289 2290 // Create the blocks within this region. We do this before processing so that 2291 // we can rely on the blocks existing when creating operations. 2292 readState.curBlocks.clear(); 2293 readState.curBlocks.reserve(numBlocks); 2294 for (uint64_t i = 0; i < numBlocks; ++i) { 2295 readState.curBlocks.push_back(new Block()); 2296 readState.curRegion->push_back(readState.curBlocks.back()); 2297 } 2298 2299 // Prepare the current value scope for this region. 2300 valueScopes.back().push(readState); 2301 2302 // Parse the entry block of the region. 2303 readState.curBlock = readState.curRegion->begin(); 2304 return parseBlockHeader(reader, readState); 2305 } 2306 2307 LogicalResult 2308 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, 2309 RegionReadState &readState) { 2310 bool hasArgs; 2311 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) 2312 return failure(); 2313 2314 // Parse the arguments of the block. 2315 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) 2316 return failure(); 2317 2318 // Uselist orders are available since version 3 of the bytecode. 2319 if (version < 3) 2320 return success(); 2321 2322 uint8_t hasUseListOrders = 0; 2323 if (hasArgs && failed(reader.parseByte(hasUseListOrders))) 2324 return failure(); 2325 2326 if (!hasUseListOrders) 2327 return success(); 2328 2329 Block &blk = *readState.curBlock; 2330 auto argIdxToUseListMap = 2331 parseUseListOrderForRange(reader, blk.getNumArguments()); 2332 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) 2333 return failure(); 2334 2335 for (size_t idx = 0; idx < blk.getNumArguments(); idx++) 2336 if (argIdxToUseListMap->contains(idx)) 2337 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), 2338 argIdxToUseListMap->at(idx)); 2339 2340 // We don't parse the operations of the block here, that's done elsewhere. 2341 return success(); 2342 } 2343 2344 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, 2345 Block *block) { 2346 // Parse the value ID for the first argument, and the number of arguments. 2347 uint64_t numArgs; 2348 if (failed(reader.parseVarInt(numArgs))) 2349 return failure(); 2350 2351 SmallVector<Type> argTypes; 2352 SmallVector<Location> argLocs; 2353 argTypes.reserve(numArgs); 2354 argLocs.reserve(numArgs); 2355 2356 Location unknownLoc = UnknownLoc::get(config.getContext()); 2357 while (numArgs--) { 2358 Type argType; 2359 LocationAttr argLoc = unknownLoc; 2360 if (version > 3) { 2361 // Parse the type with hasLoc flag to determine if it has type. 2362 uint64_t typeIdx; 2363 bool hasLoc; 2364 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || 2365 !(argType = attrTypeReader.resolveType(typeIdx))) 2366 return failure(); 2367 if (hasLoc && failed(parseAttribute(reader, argLoc))) 2368 return failure(); 2369 } else { 2370 // All args has type and location. 2371 if (failed(parseType(reader, argType)) || 2372 failed(parseAttribute(reader, argLoc))) 2373 return failure(); 2374 } 2375 argTypes.push_back(argType); 2376 argLocs.push_back(argLoc); 2377 } 2378 block->addArguments(argTypes, argLocs); 2379 return defineValues(reader, block->getArguments()); 2380 } 2381 2382 //===----------------------------------------------------------------------===// 2383 // Value Processing 2384 2385 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { 2386 std::vector<Value> &values = valueScopes.back().values; 2387 Value *value = nullptr; 2388 if (failed(parseEntry(reader, values, value, "value"))) 2389 return Value(); 2390 2391 // Create a new forward reference if necessary. 2392 if (!*value) 2393 *value = createForwardRef(); 2394 return *value; 2395 } 2396 2397 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, 2398 ValueRange newValues) { 2399 ValueScope &valueScope = valueScopes.back(); 2400 std::vector<Value> &values = valueScope.values; 2401 2402 unsigned &valueID = valueScope.nextValueIDs.back(); 2403 unsigned valueIDEnd = valueID + newValues.size(); 2404 if (valueIDEnd > values.size()) { 2405 return reader.emitError( 2406 "value index range was outside of the expected range for " 2407 "the parent region, got [", 2408 valueID, ", ", valueIDEnd, "), but the maximum index was ", 2409 values.size() - 1); 2410 } 2411 2412 // Assign the values and update any forward references. 2413 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { 2414 Value newValue = newValues[i]; 2415 2416 // Check to see if a definition for this value already exists. 2417 if (Value oldValue = std::exchange(values[valueID], newValue)) { 2418 Operation *forwardRefOp = oldValue.getDefiningOp(); 2419 2420 // Assert that this is a forward reference operation. Given how we compute 2421 // definition ids (incrementally as we parse), it shouldn't be possible 2422 // for the value to be defined any other way. 2423 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && 2424 "value index was already defined?"); 2425 2426 oldValue.replaceAllUsesWith(newValue); 2427 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); 2428 } 2429 } 2430 return success(); 2431 } 2432 2433 Value BytecodeReader::Impl::createForwardRef() { 2434 // Check for an avaliable existing operation to use. Otherwise, create a new 2435 // fake operation to use for the reference. 2436 if (!openForwardRefOps.empty()) { 2437 Operation *op = &openForwardRefOps.back(); 2438 op->moveBefore(&forwardRefOps, forwardRefOps.end()); 2439 } else { 2440 forwardRefOps.push_back(Operation::create(forwardRefOpState)); 2441 } 2442 return forwardRefOps.back().getResult(0); 2443 } 2444 2445 //===----------------------------------------------------------------------===// 2446 // Entry Points 2447 //===----------------------------------------------------------------------===// 2448 2449 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } 2450 2451 BytecodeReader::BytecodeReader( 2452 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, 2453 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2454 Location sourceFileLoc = 2455 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2456 /*line=*/0, /*column=*/0); 2457 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, 2458 bufferOwnerRef); 2459 } 2460 2461 LogicalResult BytecodeReader::readTopLevel( 2462 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2463 return impl->read(block, lazyOpsCallback); 2464 } 2465 2466 int64_t BytecodeReader::getNumOpsToMaterialize() const { 2467 return impl->getNumOpsToMaterialize(); 2468 } 2469 2470 bool BytecodeReader::isMaterializable(Operation *op) { 2471 return impl->isMaterializable(op); 2472 } 2473 2474 LogicalResult BytecodeReader::materialize( 2475 Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2476 return impl->materialize(op, lazyOpsCallback); 2477 } 2478 2479 LogicalResult 2480 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { 2481 return impl->finalize(shouldMaterialize); 2482 } 2483 2484 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { 2485 return buffer.getBuffer().startswith("ML\xefR"); 2486 } 2487 2488 /// Read the bytecode from the provided memory buffer reference. 2489 /// `bufferOwnerRef` if provided is the owning source manager for the buffer, 2490 /// and may be used to extend the lifetime of the buffer. 2491 static LogicalResult 2492 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, 2493 const ParserConfig &config, 2494 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2495 Location sourceFileLoc = 2496 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2497 /*line=*/0, /*column=*/0); 2498 if (!isBytecode(buffer)) { 2499 return emitError(sourceFileLoc, 2500 "input buffer is not an MLIR bytecode file"); 2501 } 2502 2503 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, 2504 buffer, bufferOwnerRef); 2505 return reader.read(block, /*lazyOpsCallback=*/nullptr); 2506 } 2507 2508 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, 2509 const ParserConfig &config) { 2510 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); 2511 } 2512 LogicalResult 2513 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 2514 Block *block, const ParserConfig &config) { 2515 return readBytecodeFileImpl( 2516 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, 2517 sourceMgr); 2518 } 2519