1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 // TODO: Support for big-endian architectures. 10 11 #include "mlir/Bytecode/BytecodeReader.h" 12 #include "mlir/AsmParser/AsmParser.h" 13 #include "mlir/Bytecode/BytecodeImplementation.h" 14 #include "mlir/Bytecode/BytecodeOpInterface.h" 15 #include "mlir/Bytecode/Encoding.h" 16 #include "mlir/IR/BuiltinDialect.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/Diagnostics.h" 19 #include "mlir/IR/OpImplementation.h" 20 #include "mlir/IR/Verifier.h" 21 #include "mlir/IR/Visitors.h" 22 #include "mlir/Support/LLVM.h" 23 #include "mlir/Support/LogicalResult.h" 24 #include "llvm/ADT/ArrayRef.h" 25 #include "llvm/ADT/MapVector.h" 26 #include "llvm/ADT/ScopeExit.h" 27 #include "llvm/ADT/SmallString.h" 28 #include "llvm/ADT/StringExtras.h" 29 #include "llvm/ADT/StringRef.h" 30 #include "llvm/Support/MemoryBufferRef.h" 31 #include "llvm/Support/SaveAndRestore.h" 32 #include "llvm/Support/SourceMgr.h" 33 #include <cstddef> 34 #include <list> 35 #include <memory> 36 #include <numeric> 37 #include <optional> 38 39 #define DEBUG_TYPE "mlir-bytecode-reader" 40 41 using namespace mlir; 42 43 /// Stringify the given section ID. 44 static std::string toString(bytecode::Section::ID sectionID) { 45 switch (sectionID) { 46 case bytecode::Section::kString: 47 return "String (0)"; 48 case bytecode::Section::kDialect: 49 return "Dialect (1)"; 50 case bytecode::Section::kAttrType: 51 return "AttrType (2)"; 52 case bytecode::Section::kAttrTypeOffset: 53 return "AttrTypeOffset (3)"; 54 case bytecode::Section::kIR: 55 return "IR (4)"; 56 case bytecode::Section::kResource: 57 return "Resource (5)"; 58 case bytecode::Section::kResourceOffset: 59 return "ResourceOffset (6)"; 60 case bytecode::Section::kDialectVersions: 61 return "DialectVersions (7)"; 62 case bytecode::Section::kProperties: 63 return "Properties (8)"; 64 default: 65 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); 66 } 67 } 68 69 /// Returns true if the given top-level section ID is optional. 70 static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { 71 switch (sectionID) { 72 case bytecode::Section::kString: 73 case bytecode::Section::kDialect: 74 case bytecode::Section::kAttrType: 75 case bytecode::Section::kAttrTypeOffset: 76 case bytecode::Section::kIR: 77 return false; 78 case bytecode::Section::kResource: 79 case bytecode::Section::kResourceOffset: 80 case bytecode::Section::kDialectVersions: 81 return true; 82 case bytecode::Section::kProperties: 83 return version < 4; 84 default: 85 llvm_unreachable("unknown section ID"); 86 } 87 } 88 89 //===----------------------------------------------------------------------===// 90 // EncodingReader 91 //===----------------------------------------------------------------------===// 92 93 namespace { 94 class EncodingReader { 95 public: 96 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) 97 : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} 98 explicit EncodingReader(StringRef contents, Location fileLoc) 99 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), 100 contents.size()}, 101 fileLoc) {} 102 103 /// Returns true if the entire section has been read. 104 bool empty() const { return dataIt == dataEnd; } 105 106 /// Returns the remaining size of the bytecode. 107 size_t size() const { return dataEnd - dataIt; } 108 109 /// Align the current reader position to the specified alignment. 110 LogicalResult alignTo(unsigned alignment) { 111 if (!llvm::isPowerOf2_32(alignment)) 112 return emitError("expected alignment to be a power-of-two"); 113 114 // Shift the reader position to the next alignment boundary. 115 while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) { 116 uint8_t padding; 117 if (failed(parseByte(padding))) 118 return failure(); 119 if (padding != bytecode::kAlignmentByte) { 120 return emitError("expected alignment byte (0xCB), but got: '0x" + 121 llvm::utohexstr(padding) + "'"); 122 } 123 } 124 125 // Ensure the data iterator is now aligned. This case is unlikely because we 126 // *just* went through the effort to align the data iterator. 127 if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) { 128 return emitError("expected data iterator aligned to ", alignment, 129 ", but got pointer: '0x" + 130 llvm::utohexstr((uintptr_t)dataIt) + "'"); 131 } 132 133 return success(); 134 } 135 136 /// Emit an error using the given arguments. 137 template <typename... Args> 138 InFlightDiagnostic emitError(Args &&...args) const { 139 return ::emitError(fileLoc).append(std::forward<Args>(args)...); 140 } 141 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } 142 143 /// Parse a single byte from the stream. 144 template <typename T> 145 LogicalResult parseByte(T &value) { 146 if (empty()) 147 return emitError("attempting to parse a byte at the end of the bytecode"); 148 value = static_cast<T>(*dataIt++); 149 return success(); 150 } 151 /// Parse a range of bytes of 'length' into the given result. 152 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { 153 if (length > size()) { 154 return emitError("attempting to parse ", length, " bytes when only ", 155 size(), " remain"); 156 } 157 result = {dataIt, length}; 158 dataIt += length; 159 return success(); 160 } 161 /// Parse a range of bytes of 'length' into the given result, which can be 162 /// assumed to be large enough to hold `length`. 163 LogicalResult parseBytes(size_t length, uint8_t *result) { 164 if (length > size()) { 165 return emitError("attempting to parse ", length, " bytes when only ", 166 size(), " remain"); 167 } 168 memcpy(result, dataIt, length); 169 dataIt += length; 170 return success(); 171 } 172 173 /// Parse an aligned blob of data, where the alignment was encoded alongside 174 /// the data. 175 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, 176 uint64_t &alignment) { 177 uint64_t dataSize; 178 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || 179 failed(alignTo(alignment))) 180 return failure(); 181 return parseBytes(dataSize, data); 182 } 183 184 /// Parse a variable length encoded integer from the byte stream. The first 185 /// encoded byte contains a prefix in the low bits indicating the encoded 186 /// length of the value. This length prefix is a bit sequence of '0's followed 187 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes 188 /// (not including the prefix byte). All remaining bits in the first byte, 189 /// along with all of the bits in additional bytes, provide the value of the 190 /// integer encoded in little-endian order. 191 LogicalResult parseVarInt(uint64_t &result) { 192 // Parse the first byte of the encoding, which contains the length prefix. 193 if (failed(parseByte(result))) 194 return failure(); 195 196 // Handle the overwhelmingly common case where the value is stored in a 197 // single byte. In this case, the first bit is the `1` marker bit. 198 if (LLVM_LIKELY(result & 1)) { 199 result >>= 1; 200 return success(); 201 } 202 203 // Handle the overwhelming uncommon case where the value required all 8 204 // bytes (i.e. a really really big number). In this case, the marker byte is 205 // all zeros: `00000000`. 206 if (LLVM_UNLIKELY(result == 0)) 207 return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result)); 208 return parseMultiByteVarInt(result); 209 } 210 211 /// Parse a signed variable length encoded integer from the byte stream. A 212 /// signed varint is encoded as a normal varint with zigzag encoding applied, 213 /// i.e. the low bit of the value is used to indicate the sign. 214 LogicalResult parseSignedVarInt(uint64_t &result) { 215 if (failed(parseVarInt(result))) 216 return failure(); 217 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) 218 result = (result >> 1) ^ (~(result & 1) + 1); 219 return success(); 220 } 221 222 /// Parse a variable length encoded integer whose low bit is used to encode an 223 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. 224 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { 225 if (failed(parseVarInt(result))) 226 return failure(); 227 flag = result & 1; 228 result >>= 1; 229 return success(); 230 } 231 232 /// Skip the first `length` bytes within the reader. 233 LogicalResult skipBytes(size_t length) { 234 if (length > size()) { 235 return emitError("attempting to skip ", length, " bytes when only ", 236 size(), " remain"); 237 } 238 dataIt += length; 239 return success(); 240 } 241 242 /// Parse a null-terminated string into `result` (without including the NUL 243 /// terminator). 244 LogicalResult parseNullTerminatedString(StringRef &result) { 245 const char *startIt = (const char *)dataIt; 246 const char *nulIt = (const char *)memchr(startIt, 0, size()); 247 if (!nulIt) 248 return emitError( 249 "malformed null-terminated string, no null character found"); 250 251 result = StringRef(startIt, nulIt - startIt); 252 dataIt = (const uint8_t *)nulIt + 1; 253 return success(); 254 } 255 256 /// Parse a section header, placing the kind of section in `sectionID` and the 257 /// contents of the section in `sectionData`. 258 LogicalResult parseSection(bytecode::Section::ID §ionID, 259 ArrayRef<uint8_t> §ionData) { 260 uint8_t sectionIDAndHasAlignment; 261 uint64_t length; 262 if (failed(parseByte(sectionIDAndHasAlignment)) || 263 failed(parseVarInt(length))) 264 return failure(); 265 266 // Extract the section ID and whether the section is aligned. The high bit 267 // of the ID is the alignment flag. 268 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & 269 0b01111111); 270 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; 271 272 // Check that the section is actually valid before trying to process its 273 // data. 274 if (sectionID >= bytecode::Section::kNumSections) 275 return emitError("invalid section ID: ", unsigned(sectionID)); 276 277 // Process the section alignment if present. 278 if (hasAlignment) { 279 uint64_t alignment; 280 if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) 281 return failure(); 282 } 283 284 // Parse the actual section data. 285 return parseBytes(static_cast<size_t>(length), sectionData); 286 } 287 288 Location getLoc() const { return fileLoc; } 289 290 private: 291 /// Parse a variable length encoded integer from the byte stream. This method 292 /// is a fallback when the number of bytes used to encode the value is greater 293 /// than 1, but less than the max (9). The provided `result` value can be 294 /// assumed to already contain the first byte of the value. 295 /// NOTE: This method is marked noinline to avoid pessimizing the common case 296 /// of single byte encoding. 297 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) { 298 // Count the number of trailing zeros in the marker byte, this indicates the 299 // number of trailing bytes that are part of the value. We use `uint32_t` 300 // here because we only care about the first byte, and so that be actually 301 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop 302 // implementation). 303 uint32_t numBytes = llvm::countr_zero<uint32_t>(result); 304 assert(numBytes > 0 && numBytes <= 7 && 305 "unexpected number of trailing zeros in varint encoding"); 306 307 // Parse in the remaining bytes of the value. 308 if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1))) 309 return failure(); 310 311 // Shift out the low-order bits that were used to mark how the value was 312 // encoded. 313 result >>= (numBytes + 1); 314 return success(); 315 } 316 317 /// The current data iterator, and an iterator to the end of the buffer. 318 const uint8_t *dataIt, *dataEnd; 319 320 /// A location for the bytecode used to report errors. 321 Location fileLoc; 322 }; 323 } // namespace 324 325 /// Resolve an index into the given entry list. `entry` may either be a 326 /// reference, in which case it is assigned to the corresponding value in 327 /// `entries`, or a pointer, in which case it is assigned to the address of the 328 /// element in `entries`. 329 template <typename RangeT, typename T> 330 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, 331 uint64_t index, T &entry, 332 StringRef entryStr) { 333 if (index >= entries.size()) 334 return reader.emitError("invalid ", entryStr, " index: ", index); 335 336 // If the provided entry is a pointer, resolve to the address of the entry. 337 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) 338 entry = entries[index]; 339 else 340 entry = &entries[index]; 341 return success(); 342 } 343 344 /// Parse and resolve an index into the given entry list. 345 template <typename RangeT, typename T> 346 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, 347 T &entry, StringRef entryStr) { 348 uint64_t entryIdx; 349 if (failed(reader.parseVarInt(entryIdx))) 350 return failure(); 351 return resolveEntry(reader, entries, entryIdx, entry, entryStr); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // StringSectionReader 356 //===----------------------------------------------------------------------===// 357 358 namespace { 359 /// This class is used to read references to the string section from the 360 /// bytecode. 361 class StringSectionReader { 362 public: 363 /// Initialize the string section reader with the given section data. 364 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); 365 366 /// Parse a shared string from the string section. The shared string is 367 /// encoded using an index to a corresponding string in the string section. 368 LogicalResult parseString(EncodingReader &reader, StringRef &result) { 369 return parseEntry(reader, strings, result, "string"); 370 } 371 372 /// Parse a shared string from the string section. The shared string is 373 /// encoded using an index to a corresponding string in the string section. 374 /// This variant parses a flag compressed with the index. 375 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, 376 bool &flag) { 377 uint64_t entryIdx; 378 if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) 379 return failure(); 380 return parseStringAtIndex(reader, entryIdx, result); 381 } 382 383 /// Parse a shared string from the string section. The shared string is 384 /// encoded using an index to a corresponding string in the string section. 385 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, 386 StringRef &result) { 387 return resolveEntry(reader, strings, index, result, "string"); 388 } 389 390 private: 391 /// The table of strings referenced within the bytecode file. 392 SmallVector<StringRef> strings; 393 }; 394 } // namespace 395 396 LogicalResult StringSectionReader::initialize(Location fileLoc, 397 ArrayRef<uint8_t> sectionData) { 398 EncodingReader stringReader(sectionData, fileLoc); 399 400 // Parse the number of strings in the section. 401 uint64_t numStrings; 402 if (failed(stringReader.parseVarInt(numStrings))) 403 return failure(); 404 strings.resize(numStrings); 405 406 // Parse each of the strings. The sizes of the strings are encoded in reverse 407 // order, so that's the order we populate the table. 408 size_t stringDataEndOffset = sectionData.size(); 409 for (StringRef &string : llvm::reverse(strings)) { 410 uint64_t stringSize; 411 if (failed(stringReader.parseVarInt(stringSize))) 412 return failure(); 413 if (stringDataEndOffset < stringSize) { 414 return stringReader.emitError( 415 "string size exceeds the available data size"); 416 } 417 418 // Extract the string from the data, dropping the null character. 419 size_t stringOffset = stringDataEndOffset - stringSize; 420 string = StringRef( 421 reinterpret_cast<const char *>(sectionData.data() + stringOffset), 422 stringSize - 1); 423 stringDataEndOffset = stringOffset; 424 } 425 426 // Check that the only remaining data was for the strings, i.e. the reader 427 // should be at the same offset as the first string. 428 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { 429 return stringReader.emitError("unexpected trailing data between the " 430 "offsets for strings and their data"); 431 } 432 return success(); 433 } 434 435 //===----------------------------------------------------------------------===// 436 // BytecodeDialect 437 //===----------------------------------------------------------------------===// 438 439 namespace { 440 class DialectReader; 441 442 /// This struct represents a dialect entry within the bytecode. 443 struct BytecodeDialect { 444 /// Load the dialect into the provided context if it hasn't been loaded yet. 445 /// Returns failure if the dialect couldn't be loaded *and* the provided 446 /// context does not allow unregistered dialects. The provided reader is used 447 /// for error emission if necessary. 448 LogicalResult load(DialectReader &reader, MLIRContext *ctx); 449 450 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can 451 /// only be called after `load`. 452 Dialect *getLoadedDialect() const { 453 assert(dialect && 454 "expected `load` to be invoked before `getLoadedDialect`"); 455 return *dialect; 456 } 457 458 /// The loaded dialect entry. This field is std::nullopt if we haven't 459 /// attempted to load, nullptr if we failed to load, otherwise the loaded 460 /// dialect. 461 std::optional<Dialect *> dialect; 462 463 /// The bytecode interface of the dialect, or nullptr if the dialect does not 464 /// implement the bytecode interface. This field should only be checked if the 465 /// `dialect` field is not std::nullopt. 466 const BytecodeDialectInterface *interface = nullptr; 467 468 /// The name of the dialect. 469 StringRef name; 470 471 /// A buffer containing the encoding of the dialect version parsed. 472 ArrayRef<uint8_t> versionBuffer; 473 474 /// Lazy loaded dialect version from the handle above. 475 std::unique_ptr<DialectVersion> loadedVersion; 476 }; 477 478 /// This struct represents an operation name entry within the bytecode. 479 struct BytecodeOperationName { 480 BytecodeOperationName(BytecodeDialect *dialect, StringRef name, 481 bool wasRegistered) 482 : dialect(dialect), name(name), wasRegistered(wasRegistered) {} 483 484 /// The loaded operation name, or std::nullopt if it hasn't been processed 485 /// yet. 486 std::optional<OperationName> opName; 487 488 /// The dialect that owns this operation name. 489 BytecodeDialect *dialect; 490 491 /// The name of the operation, without the dialect prefix. 492 StringRef name; 493 494 /// Whether this operation was registered when the bytecode was produced. 495 /// This flag is populated when bytecode version >=4. 496 bool wasRegistered; 497 }; 498 } // namespace 499 500 /// Parse a single dialect group encoded in the byte stream. 501 static LogicalResult parseDialectGrouping( 502 EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, 503 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { 504 // Parse the dialect and the number of entries in the group. 505 BytecodeDialect *dialect; 506 if (failed(parseEntry(reader, dialects, dialect, "dialect"))) 507 return failure(); 508 uint64_t numEntries; 509 if (failed(reader.parseVarInt(numEntries))) 510 return failure(); 511 512 for (uint64_t i = 0; i < numEntries; ++i) 513 if (failed(entryCallback(dialect))) 514 return failure(); 515 return success(); 516 } 517 518 //===----------------------------------------------------------------------===// 519 // ResourceSectionReader 520 //===----------------------------------------------------------------------===// 521 522 namespace { 523 /// This class is used to read the resource section from the bytecode. 524 class ResourceSectionReader { 525 public: 526 /// Initialize the resource section reader with the given section data. 527 LogicalResult 528 initialize(Location fileLoc, const ParserConfig &config, 529 MutableArrayRef<BytecodeDialect> dialects, 530 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 531 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 532 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); 533 534 /// Parse a dialect resource handle from the resource section. 535 LogicalResult parseResourceHandle(EncodingReader &reader, 536 AsmDialectResourceHandle &result) { 537 return parseEntry(reader, dialectResources, result, "resource handle"); 538 } 539 540 private: 541 /// The table of dialect resources within the bytecode file. 542 SmallVector<AsmDialectResourceHandle> dialectResources; 543 llvm::StringMap<std::string> dialectResourceHandleRenamingMap; 544 }; 545 546 class ParsedResourceEntry : public AsmParsedResourceEntry { 547 public: 548 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, 549 EncodingReader &reader, StringSectionReader &stringReader, 550 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 551 : key(key), kind(kind), reader(reader), stringReader(stringReader), 552 bufferOwnerRef(bufferOwnerRef) {} 553 ~ParsedResourceEntry() override = default; 554 555 StringRef getKey() const final { return key; } 556 557 InFlightDiagnostic emitError() const final { return reader.emitError(); } 558 559 AsmResourceEntryKind getKind() const final { return kind; } 560 561 FailureOr<bool> parseAsBool() const final { 562 if (kind != AsmResourceEntryKind::Bool) 563 return emitError() << "expected a bool resource entry, but found a " 564 << toString(kind) << " entry instead"; 565 566 bool value; 567 if (failed(reader.parseByte(value))) 568 return failure(); 569 return value; 570 } 571 FailureOr<std::string> parseAsString() const final { 572 if (kind != AsmResourceEntryKind::String) 573 return emitError() << "expected a string resource entry, but found a " 574 << toString(kind) << " entry instead"; 575 576 StringRef string; 577 if (failed(stringReader.parseString(reader, string))) 578 return failure(); 579 return string.str(); 580 } 581 582 FailureOr<AsmResourceBlob> 583 parseAsBlob(BlobAllocatorFn allocator) const final { 584 if (kind != AsmResourceEntryKind::Blob) 585 return emitError() << "expected a blob resource entry, but found a " 586 << toString(kind) << " entry instead"; 587 588 ArrayRef<uint8_t> data; 589 uint64_t alignment; 590 if (failed(reader.parseBlobAndAlignment(data, alignment))) 591 return failure(); 592 593 // If we have an extendable reference to the buffer owner, we don't need to 594 // allocate a new buffer for the data, and can use the data directly. 595 if (bufferOwnerRef) { 596 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), 597 data.size()); 598 599 // Allocate an unmanager buffer which captures a reference to the owner. 600 // For now we just mark this as immutable, but in the future we should 601 // explore marking this as mutable when desired. 602 return UnmanagedAsmResourceBlob::allocateWithAlign( 603 charData, alignment, 604 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); 605 } 606 607 // Allocate memory for the blob using the provided allocator and copy the 608 // data into it. 609 AsmResourceBlob blob = allocator(data.size(), alignment); 610 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && 611 blob.isMutable() && 612 "blob allocator did not return a properly aligned address"); 613 memcpy(blob.getMutableData().data(), data.data(), data.size()); 614 return blob; 615 } 616 617 private: 618 StringRef key; 619 AsmResourceEntryKind kind; 620 EncodingReader &reader; 621 StringSectionReader &stringReader; 622 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 623 }; 624 } // namespace 625 626 template <typename T> 627 static LogicalResult 628 parseResourceGroup(Location fileLoc, bool allowEmpty, 629 EncodingReader &offsetReader, EncodingReader &resourceReader, 630 StringSectionReader &stringReader, T *handler, 631 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, 632 function_ref<StringRef(StringRef)> remapKey = {}, 633 function_ref<LogicalResult(StringRef)> processKeyFn = {}) { 634 uint64_t numResources; 635 if (failed(offsetReader.parseVarInt(numResources))) 636 return failure(); 637 638 for (uint64_t i = 0; i < numResources; ++i) { 639 StringRef key; 640 AsmResourceEntryKind kind; 641 uint64_t resourceOffset; 642 ArrayRef<uint8_t> data; 643 if (failed(stringReader.parseString(offsetReader, key)) || 644 failed(offsetReader.parseVarInt(resourceOffset)) || 645 failed(offsetReader.parseByte(kind)) || 646 failed(resourceReader.parseBytes(resourceOffset, data))) 647 return failure(); 648 649 // Process the resource key. 650 if ((processKeyFn && failed(processKeyFn(key)))) 651 return failure(); 652 653 // If the resource data is empty and we allow it, don't error out when 654 // parsing below, just skip it. 655 if (allowEmpty && data.empty()) 656 continue; 657 658 // Ignore the entry if we don't have a valid handler. 659 if (!handler) 660 continue; 661 662 // Otherwise, parse the resource value. 663 EncodingReader entryReader(data, fileLoc); 664 key = remapKey(key); 665 ParsedResourceEntry entry(key, kind, entryReader, stringReader, 666 bufferOwnerRef); 667 if (failed(handler->parseResource(entry))) 668 return failure(); 669 if (!entryReader.empty()) { 670 return entryReader.emitError( 671 "unexpected trailing bytes in resource entry '", key, "'"); 672 } 673 } 674 return success(); 675 } 676 677 LogicalResult ResourceSectionReader::initialize( 678 Location fileLoc, const ParserConfig &config, 679 MutableArrayRef<BytecodeDialect> dialects, 680 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 681 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 682 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 683 EncodingReader resourceReader(sectionData, fileLoc); 684 EncodingReader offsetReader(offsetSectionData, fileLoc); 685 686 // Read the number of external resource providers. 687 uint64_t numExternalResourceGroups; 688 if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) 689 return failure(); 690 691 // Utility functor that dispatches to `parseResourceGroup`, but implicitly 692 // provides most of the arguments. 693 auto parseGroup = [&](auto *handler, bool allowEmpty = false, 694 function_ref<LogicalResult(StringRef)> keyFn = {}) { 695 auto resolveKey = [&](StringRef key) -> StringRef { 696 auto it = dialectResourceHandleRenamingMap.find(key); 697 if (it == dialectResourceHandleRenamingMap.end()) 698 return ""; 699 return it->second; 700 }; 701 702 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, 703 stringReader, handler, bufferOwnerRef, resolveKey, 704 keyFn); 705 }; 706 707 // Read the external resources from the bytecode. 708 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { 709 StringRef key; 710 if (failed(stringReader.parseString(offsetReader, key))) 711 return failure(); 712 713 // Get the handler for these resources. 714 // TODO: Should we require handling external resources in some scenarios? 715 AsmResourceParser *handler = config.getResourceParser(key); 716 if (!handler) { 717 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key 718 << "'"; 719 } 720 721 if (failed(parseGroup(handler))) 722 return failure(); 723 } 724 725 // Read the dialect resources from the bytecode. 726 MLIRContext *ctx = fileLoc->getContext(); 727 while (!offsetReader.empty()) { 728 BytecodeDialect *dialect; 729 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || 730 failed(dialect->load(dialectReader, ctx))) 731 return failure(); 732 Dialect *loadedDialect = dialect->getLoadedDialect(); 733 if (!loadedDialect) { 734 return resourceReader.emitError() 735 << "dialect '" << dialect->name << "' is unknown"; 736 } 737 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); 738 if (!handler) { 739 return resourceReader.emitError() 740 << "unexpected resources for dialect '" << dialect->name << "'"; 741 } 742 743 // Ensure that each resource is declared before being processed. 744 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { 745 FailureOr<AsmDialectResourceHandle> handle = 746 handler->declareResource(key); 747 if (failed(handle)) { 748 return resourceReader.emitError() 749 << "unknown 'resource' key '" << key << "' for dialect '" 750 << dialect->name << "'"; 751 } 752 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); 753 dialectResources.push_back(*handle); 754 return success(); 755 }; 756 757 // Parse the resources for this dialect. We allow empty resources because we 758 // just treat these as declarations. 759 if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) 760 return failure(); 761 } 762 763 return success(); 764 } 765 766 //===----------------------------------------------------------------------===// 767 // Attribute/Type Reader 768 //===----------------------------------------------------------------------===// 769 770 namespace { 771 /// This class provides support for reading attribute and type entries from the 772 /// bytecode. Attribute and Type entries are read lazily on demand, so we use 773 /// this reader to manage when to actually parse them from the bytecode. 774 class AttrTypeReader { 775 /// This class represents a single attribute or type entry. 776 template <typename T> 777 struct Entry { 778 /// The entry, or null if it hasn't been resolved yet. 779 T entry = {}; 780 /// The parent dialect of this entry. 781 BytecodeDialect *dialect = nullptr; 782 /// A flag indicating if the entry was encoded using a custom encoding, 783 /// instead of using the textual assembly format. 784 bool hasCustomEncoding = false; 785 /// The raw data of this entry in the bytecode. 786 ArrayRef<uint8_t> data; 787 }; 788 using AttrEntry = Entry<Attribute>; 789 using TypeEntry = Entry<Type>; 790 791 public: 792 AttrTypeReader(StringSectionReader &stringReader, 793 ResourceSectionReader &resourceReader, Location fileLoc) 794 : stringReader(stringReader), resourceReader(resourceReader), 795 fileLoc(fileLoc) {} 796 797 /// Initialize the attribute and type information within the reader. 798 LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, 799 ArrayRef<uint8_t> sectionData, 800 ArrayRef<uint8_t> offsetSectionData); 801 802 /// Resolve the attribute or type at the given index. Returns nullptr on 803 /// failure. 804 Attribute resolveAttribute(size_t index) { 805 return resolveEntry(attributes, index, "Attribute"); 806 } 807 Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } 808 809 /// Parse a reference to an attribute or type using the given reader. 810 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { 811 uint64_t attrIdx; 812 if (failed(reader.parseVarInt(attrIdx))) 813 return failure(); 814 result = resolveAttribute(attrIdx); 815 return success(!!result); 816 } 817 LogicalResult parseOptionalAttribute(EncodingReader &reader, 818 Attribute &result) { 819 uint64_t attrIdx; 820 bool flag; 821 if (failed(reader.parseVarIntWithFlag(attrIdx, flag))) 822 return failure(); 823 if (!flag) 824 return success(); 825 result = resolveAttribute(attrIdx); 826 return success(!!result); 827 } 828 829 LogicalResult parseType(EncodingReader &reader, Type &result) { 830 uint64_t typeIdx; 831 if (failed(reader.parseVarInt(typeIdx))) 832 return failure(); 833 result = resolveType(typeIdx); 834 return success(!!result); 835 } 836 837 template <typename T> 838 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 839 Attribute baseResult; 840 if (failed(parseAttribute(reader, baseResult))) 841 return failure(); 842 if ((result = dyn_cast<T>(baseResult))) 843 return success(); 844 return reader.emitError("expected attribute of type: ", 845 llvm::getTypeName<T>(), ", but got: ", baseResult); 846 } 847 848 private: 849 /// Resolve the given entry at `index`. 850 template <typename T> 851 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 852 StringRef entryType); 853 854 /// Parse an entry using the given reader that was encoded using the textual 855 /// assembly format. 856 template <typename T> 857 LogicalResult parseAsmEntry(T &result, EncodingReader &reader, 858 StringRef entryType); 859 860 /// Parse an entry using the given reader that was encoded using a custom 861 /// bytecode format. 862 template <typename T> 863 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, 864 StringRef entryType); 865 866 /// The string section reader used to resolve string references when parsing 867 /// custom encoded attribute/type entries. 868 StringSectionReader &stringReader; 869 870 /// The resource section reader used to resolve resource references when 871 /// parsing custom encoded attribute/type entries. 872 ResourceSectionReader &resourceReader; 873 874 /// The set of attribute and type entries. 875 SmallVector<AttrEntry> attributes; 876 SmallVector<TypeEntry> types; 877 878 /// A location used for error emission. 879 Location fileLoc; 880 }; 881 882 class DialectReader : public DialectBytecodeReader { 883 public: 884 DialectReader(AttrTypeReader &attrTypeReader, 885 StringSectionReader &stringReader, 886 ResourceSectionReader &resourceReader, EncodingReader &reader) 887 : attrTypeReader(attrTypeReader), stringReader(stringReader), 888 resourceReader(resourceReader), reader(reader) {} 889 890 InFlightDiagnostic emitError(const Twine &msg) override { 891 return reader.emitError(msg); 892 } 893 894 DialectReader withEncodingReader(EncodingReader &encReader) { 895 return DialectReader(attrTypeReader, stringReader, resourceReader, 896 encReader); 897 } 898 899 Location getLoc() const { return reader.getLoc(); } 900 901 //===--------------------------------------------------------------------===// 902 // IR 903 //===--------------------------------------------------------------------===// 904 905 LogicalResult readAttribute(Attribute &result) override { 906 return attrTypeReader.parseAttribute(reader, result); 907 } 908 LogicalResult readOptionalAttribute(Attribute &result) override { 909 return attrTypeReader.parseOptionalAttribute(reader, result); 910 } 911 LogicalResult readType(Type &result) override { 912 return attrTypeReader.parseType(reader, result); 913 } 914 915 FailureOr<AsmDialectResourceHandle> readResourceHandle() override { 916 AsmDialectResourceHandle handle; 917 if (failed(resourceReader.parseResourceHandle(reader, handle))) 918 return failure(); 919 return handle; 920 } 921 922 //===--------------------------------------------------------------------===// 923 // Primitives 924 //===--------------------------------------------------------------------===// 925 926 LogicalResult readVarInt(uint64_t &result) override { 927 return reader.parseVarInt(result); 928 } 929 930 LogicalResult readSignedVarInt(int64_t &result) override { 931 uint64_t unsignedResult; 932 if (failed(reader.parseSignedVarInt(unsignedResult))) 933 return failure(); 934 result = static_cast<int64_t>(unsignedResult); 935 return success(); 936 } 937 938 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { 939 // Small values are encoded using a single byte. 940 if (bitWidth <= 8) { 941 uint8_t value; 942 if (failed(reader.parseByte(value))) 943 return failure(); 944 return APInt(bitWidth, value); 945 } 946 947 // Large values up to 64 bits are encoded using a single varint. 948 if (bitWidth <= 64) { 949 uint64_t value; 950 if (failed(reader.parseSignedVarInt(value))) 951 return failure(); 952 return APInt(bitWidth, value); 953 } 954 955 // Otherwise, for really big values we encode the array of active words in 956 // the value. 957 uint64_t numActiveWords; 958 if (failed(reader.parseVarInt(numActiveWords))) 959 return failure(); 960 SmallVector<uint64_t, 4> words(numActiveWords); 961 for (uint64_t i = 0; i < numActiveWords; ++i) 962 if (failed(reader.parseSignedVarInt(words[i]))) 963 return failure(); 964 return APInt(bitWidth, words); 965 } 966 967 FailureOr<APFloat> 968 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { 969 FailureOr<APInt> intVal = 970 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); 971 if (failed(intVal)) 972 return failure(); 973 return APFloat(semantics, *intVal); 974 } 975 976 LogicalResult readString(StringRef &result) override { 977 return stringReader.parseString(reader, result); 978 } 979 980 LogicalResult readBlob(ArrayRef<char> &result) override { 981 uint64_t dataSize; 982 ArrayRef<uint8_t> data; 983 if (failed(reader.parseVarInt(dataSize)) || 984 failed(reader.parseBytes(dataSize, data))) 985 return failure(); 986 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), 987 data.size()); 988 return success(); 989 } 990 991 private: 992 AttrTypeReader &attrTypeReader; 993 StringSectionReader &stringReader; 994 ResourceSectionReader &resourceReader; 995 EncodingReader &reader; 996 }; 997 998 /// Wraps the properties section and handles reading properties out of it. 999 class PropertiesSectionReader { 1000 public: 1001 /// Initialize the properties section reader with the given section data. 1002 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { 1003 if (sectionData.empty()) 1004 return success(); 1005 EncodingReader propReader(sectionData, fileLoc); 1006 uint64_t count; 1007 if (failed(propReader.parseVarInt(count))) 1008 return failure(); 1009 // Parse the raw properties buffer. 1010 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) 1011 return failure(); 1012 1013 EncodingReader offsetsReader(propertiesBuffers, fileLoc); 1014 offsetTable.reserve(count); 1015 for (auto idx : llvm::seq<int64_t>(0, count)) { 1016 (void)idx; 1017 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); 1018 ArrayRef<uint8_t> rawProperties; 1019 uint64_t dataSize; 1020 if (failed(offsetsReader.parseVarInt(dataSize)) || 1021 failed(offsetsReader.parseBytes(dataSize, rawProperties))) 1022 return failure(); 1023 } 1024 if (!offsetsReader.empty()) 1025 return offsetsReader.emitError() 1026 << "Broken properties section: didn't exhaust the offsets table"; 1027 return success(); 1028 } 1029 1030 LogicalResult read(Location fileLoc, DialectReader &dialectReader, 1031 OperationName *opName, OperationState &opState) { 1032 uint64_t propertiesIdx; 1033 if (failed(dialectReader.readVarInt(propertiesIdx))) 1034 return failure(); 1035 if (propertiesIdx >= offsetTable.size()) 1036 return dialectReader.emitError("Properties idx out-of-bound for ") 1037 << opName->getStringRef(); 1038 size_t propertiesOffset = offsetTable[propertiesIdx]; 1039 if (propertiesIdx >= propertiesBuffers.size()) 1040 return dialectReader.emitError("Properties offset out-of-bound for ") 1041 << opName->getStringRef(); 1042 1043 // Acquire the sub-buffer that represent the requested properties. 1044 ArrayRef<char> rawProperties; 1045 { 1046 // "Seek" to the requested offset by getting a new reader with the right 1047 // sub-buffer. 1048 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), 1049 fileLoc); 1050 // Properties are stored as a sequence of {size + raw_data}. 1051 if (failed( 1052 dialectReader.withEncodingReader(reader).readBlob(rawProperties))) 1053 return failure(); 1054 } 1055 // Setup a new reader to read from the `rawProperties` sub-buffer. 1056 EncodingReader reader( 1057 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); 1058 DialectReader propReader = dialectReader.withEncodingReader(reader); 1059 1060 auto *iface = opName->getInterface<BytecodeOpInterface>(); 1061 if (iface) 1062 return iface->readProperties(propReader, opState); 1063 if (opName->isRegistered()) 1064 return propReader.emitError( 1065 "has properties but missing BytecodeOpInterface for ") 1066 << opName->getStringRef(); 1067 // Unregistered op are storing properties as an attribute. 1068 return propReader.readAttribute(opState.propertiesAttr); 1069 } 1070 1071 private: 1072 /// The properties buffer referenced within the bytecode file. 1073 ArrayRef<uint8_t> propertiesBuffers; 1074 1075 /// Table of offset in the buffer above. 1076 SmallVector<int64_t> offsetTable; 1077 }; 1078 } // namespace 1079 1080 LogicalResult 1081 AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, 1082 ArrayRef<uint8_t> sectionData, 1083 ArrayRef<uint8_t> offsetSectionData) { 1084 EncodingReader offsetReader(offsetSectionData, fileLoc); 1085 1086 // Parse the number of attribute and type entries. 1087 uint64_t numAttributes, numTypes; 1088 if (failed(offsetReader.parseVarInt(numAttributes)) || 1089 failed(offsetReader.parseVarInt(numTypes))) 1090 return failure(); 1091 attributes.resize(numAttributes); 1092 types.resize(numTypes); 1093 1094 // A functor used to accumulate the offsets for the entries in the given 1095 // range. 1096 uint64_t currentOffset = 0; 1097 auto parseEntries = [&](auto &&range) { 1098 size_t currentIndex = 0, endIndex = range.size(); 1099 1100 // Parse an individual entry. 1101 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { 1102 auto &entry = range[currentIndex++]; 1103 1104 uint64_t entrySize; 1105 if (failed(offsetReader.parseVarIntWithFlag(entrySize, 1106 entry.hasCustomEncoding))) 1107 return failure(); 1108 1109 // Verify that the offset is actually valid. 1110 if (currentOffset + entrySize > sectionData.size()) { 1111 return offsetReader.emitError( 1112 "Attribute or Type entry offset points past the end of section"); 1113 } 1114 1115 entry.data = sectionData.slice(currentOffset, entrySize); 1116 entry.dialect = dialect; 1117 currentOffset += entrySize; 1118 return success(); 1119 }; 1120 while (currentIndex != endIndex) 1121 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) 1122 return failure(); 1123 return success(); 1124 }; 1125 1126 // Process each of the attributes, and then the types. 1127 if (failed(parseEntries(attributes)) || failed(parseEntries(types))) 1128 return failure(); 1129 1130 // Ensure that we read everything from the section. 1131 if (!offsetReader.empty()) { 1132 return offsetReader.emitError( 1133 "unexpected trailing data in the Attribute/Type offset section"); 1134 } 1135 return success(); 1136 } 1137 1138 template <typename T> 1139 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 1140 StringRef entryType) { 1141 if (index >= entries.size()) { 1142 emitError(fileLoc) << "invalid " << entryType << " index: " << index; 1143 return {}; 1144 } 1145 1146 // If the entry has already been resolved, there is nothing left to do. 1147 Entry<T> &entry = entries[index]; 1148 if (entry.entry) 1149 return entry.entry; 1150 1151 // Parse the entry. 1152 EncodingReader reader(entry.data, fileLoc); 1153 1154 // Parse based on how the entry was encoded. 1155 if (entry.hasCustomEncoding) { 1156 if (failed(parseCustomEntry(entry, reader, entryType))) 1157 return T(); 1158 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { 1159 return T(); 1160 } 1161 1162 if (!reader.empty()) { 1163 reader.emitError("unexpected trailing bytes after " + entryType + " entry"); 1164 return T(); 1165 } 1166 return entry.entry; 1167 } 1168 1169 template <typename T> 1170 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, 1171 StringRef entryType) { 1172 StringRef asmStr; 1173 if (failed(reader.parseNullTerminatedString(asmStr))) 1174 return failure(); 1175 1176 // Invoke the MLIR assembly parser to parse the entry text. 1177 size_t numRead = 0; 1178 MLIRContext *context = fileLoc->getContext(); 1179 if constexpr (std::is_same_v<T, Type>) 1180 result = 1181 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); 1182 else 1183 result = ::parseAttribute(asmStr, context, Type(), &numRead, 1184 /*isKnownNullTerminated=*/true); 1185 if (!result) 1186 return failure(); 1187 1188 // Ensure there weren't dangling characters after the entry. 1189 if (numRead != asmStr.size()) { 1190 return reader.emitError("trailing characters found after ", entryType, 1191 " assembly format: ", asmStr.drop_front(numRead)); 1192 } 1193 return success(); 1194 } 1195 1196 template <typename T> 1197 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, 1198 EncodingReader &reader, 1199 StringRef entryType) { 1200 DialectReader dialectReader(*this, stringReader, resourceReader, reader); 1201 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) 1202 return failure(); 1203 // Ensure that the dialect implements the bytecode interface. 1204 if (!entry.dialect->interface) { 1205 return reader.emitError("dialect '", entry.dialect->name, 1206 "' does not implement the bytecode interface"); 1207 } 1208 1209 // Ask the dialect to parse the entry. If the dialect is versioned, parse 1210 // using the versioned encoding readers. 1211 if (entry.dialect->loadedVersion.get()) { 1212 if constexpr (std::is_same_v<T, Type>) 1213 entry.entry = entry.dialect->interface->readType( 1214 dialectReader, *entry.dialect->loadedVersion); 1215 else 1216 entry.entry = entry.dialect->interface->readAttribute( 1217 dialectReader, *entry.dialect->loadedVersion); 1218 1219 } else { 1220 if constexpr (std::is_same_v<T, Type>) 1221 entry.entry = entry.dialect->interface->readType(dialectReader); 1222 else 1223 entry.entry = entry.dialect->interface->readAttribute(dialectReader); 1224 } 1225 return success(!!entry.entry); 1226 } 1227 1228 //===----------------------------------------------------------------------===// 1229 // Bytecode Reader 1230 //===----------------------------------------------------------------------===// 1231 1232 /// This class is used to read a bytecode buffer and translate it into MLIR. 1233 class mlir::BytecodeReader::Impl { 1234 struct RegionReadState; 1235 using LazyLoadableOpsInfo = 1236 std::list<std::pair<Operation *, RegionReadState>>; 1237 using LazyLoadableOpsMap = 1238 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; 1239 1240 public: 1241 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, 1242 llvm::MemoryBufferRef buffer, 1243 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 1244 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), 1245 attrTypeReader(stringReader, resourceReader, fileLoc), 1246 // Use the builtin unrealized conversion cast operation to represent 1247 // forward references to values that aren't yet defined. 1248 forwardRefOpState(UnknownLoc::get(config.getContext()), 1249 "builtin.unrealized_conversion_cast", ValueRange(), 1250 NoneType::get(config.getContext())), 1251 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} 1252 1253 /// Read the bytecode defined within `buffer` into the given block. 1254 LogicalResult read(Block *block, 1255 llvm::function_ref<bool(Operation *)> lazyOps); 1256 1257 /// Return the number of ops that haven't been materialized yet. 1258 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } 1259 1260 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } 1261 1262 /// Materialize the provided operation, invoke the lazyOpsCallback on every 1263 /// newly found lazy operation. 1264 LogicalResult 1265 materialize(Operation *op, 1266 llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1267 this->lazyOpsCallback = lazyOpsCallback; 1268 auto resetlazyOpsCallback = 1269 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1270 auto it = lazyLoadableOpsMap.find(op); 1271 assert(it != lazyLoadableOpsMap.end() && 1272 "materialize called on non-materializable op"); 1273 return materialize(it); 1274 } 1275 1276 /// Materialize all operations. 1277 LogicalResult materializeAll() { 1278 while (!lazyLoadableOpsMap.empty()) { 1279 if (failed(materialize(lazyLoadableOpsMap.begin()))) 1280 return failure(); 1281 } 1282 return success(); 1283 } 1284 1285 /// Finalize the lazy-loading by calling back with every op that hasn't been 1286 /// materialized to let the client decide if the op should be deleted or 1287 /// materialized. The op is materialized if the callback returns true, deleted 1288 /// otherwise. 1289 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { 1290 while (!lazyLoadableOps.empty()) { 1291 Operation *op = lazyLoadableOps.begin()->first; 1292 if (shouldMaterialize(op)) { 1293 if (failed(materialize(lazyLoadableOpsMap.find(op)))) 1294 return failure(); 1295 continue; 1296 } 1297 op->dropAllReferences(); 1298 op->erase(); 1299 lazyLoadableOps.pop_front(); 1300 lazyLoadableOpsMap.erase(op); 1301 } 1302 return success(); 1303 } 1304 1305 private: 1306 LogicalResult materialize(LazyLoadableOpsMap::iterator it) { 1307 assert(it != lazyLoadableOpsMap.end() && 1308 "materialize called on non-materializable op"); 1309 valueScopes.emplace_back(); 1310 std::vector<RegionReadState> regionStack; 1311 regionStack.push_back(std::move(it->getSecond()->second)); 1312 lazyLoadableOps.erase(it->getSecond()); 1313 lazyLoadableOpsMap.erase(it); 1314 auto result = parseRegions(regionStack, regionStack.back()); 1315 assert((regionStack.empty() || failed(result)) && 1316 "broken invariant: regionStack should be empty when parseRegions " 1317 "succeeds"); 1318 return result; 1319 } 1320 1321 /// Return the context for this config. 1322 MLIRContext *getContext() const { return config.getContext(); } 1323 1324 /// Parse the bytecode version. 1325 LogicalResult parseVersion(EncodingReader &reader); 1326 1327 //===--------------------------------------------------------------------===// 1328 // Dialect Section 1329 1330 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); 1331 1332 /// Parse an operation name reference using the given reader. 1333 FailureOr<OperationName> parseOpName(EncodingReader &reader, 1334 bool &wasRegistered); 1335 1336 //===--------------------------------------------------------------------===// 1337 // Attribute/Type Section 1338 1339 /// Parse an attribute or type using the given reader. 1340 template <typename T> 1341 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 1342 return attrTypeReader.parseAttribute(reader, result); 1343 } 1344 LogicalResult parseType(EncodingReader &reader, Type &result) { 1345 return attrTypeReader.parseType(reader, result); 1346 } 1347 1348 //===--------------------------------------------------------------------===// 1349 // Resource Section 1350 1351 LogicalResult 1352 parseResourceSection(EncodingReader &reader, 1353 std::optional<ArrayRef<uint8_t>> resourceData, 1354 std::optional<ArrayRef<uint8_t>> resourceOffsetData); 1355 1356 //===--------------------------------------------------------------------===// 1357 // IR Section 1358 1359 /// This struct represents the current read state of a range of regions. This 1360 /// struct is used to enable iterative parsing of regions. 1361 struct RegionReadState { 1362 RegionReadState(Operation *op, EncodingReader *reader, 1363 bool isIsolatedFromAbove) 1364 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} 1365 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, 1366 bool isIsolatedFromAbove) 1367 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), 1368 isIsolatedFromAbove(isIsolatedFromAbove) {} 1369 1370 /// The current regions being read. 1371 MutableArrayRef<Region>::iterator curRegion, endRegion; 1372 /// This is the reader to use for this region, this pointer is pointing to 1373 /// the parent region reader unless the current region is IsolatedFromAbove, 1374 /// in which case the pointer is pointing to the `owningReader` which is a 1375 /// section dedicated to the current region. 1376 EncodingReader *reader; 1377 std::unique_ptr<EncodingReader> owningReader; 1378 1379 /// The number of values defined immediately within this region. 1380 unsigned numValues = 0; 1381 1382 /// The current blocks of the region being read. 1383 SmallVector<Block *> curBlocks; 1384 Region::iterator curBlock = {}; 1385 1386 /// The number of operations remaining to be read from the current block 1387 /// being read. 1388 uint64_t numOpsRemaining = 0; 1389 1390 /// A flag indicating if the regions being read are isolated from above. 1391 bool isIsolatedFromAbove = false; 1392 }; 1393 1394 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); 1395 LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, 1396 RegionReadState &readState); 1397 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, 1398 RegionReadState &readState, 1399 bool &isIsolatedFromAbove); 1400 1401 LogicalResult parseRegion(RegionReadState &readState); 1402 LogicalResult parseBlockHeader(EncodingReader &reader, 1403 RegionReadState &readState); 1404 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); 1405 1406 //===--------------------------------------------------------------------===// 1407 // Value Processing 1408 1409 /// Parse an operand reference using the given reader. Returns nullptr in the 1410 /// case of failure. 1411 Value parseOperand(EncodingReader &reader); 1412 1413 /// Sequentially define the given value range. 1414 LogicalResult defineValues(EncodingReader &reader, ValueRange values); 1415 1416 /// Create a value to use for a forward reference. 1417 Value createForwardRef(); 1418 1419 //===--------------------------------------------------------------------===// 1420 // Use-list order helpers 1421 1422 /// This struct is a simple storage that contains information required to 1423 /// reorder the use-list of a value with respect to the pre-order traversal 1424 /// ordering. 1425 struct UseListOrderStorage { 1426 UseListOrderStorage(bool isIndexPairEncoding, 1427 SmallVector<unsigned, 4> &&indices) 1428 : indices(std::move(indices)), 1429 isIndexPairEncoding(isIndexPairEncoding){}; 1430 /// The vector containing the information required to reorder the 1431 /// use-list of a value. 1432 SmallVector<unsigned, 4> indices; 1433 1434 /// Whether indices represent a pair of type `(src, dst)` or it is a direct 1435 /// indexing, such as `dst = order[src]`. 1436 bool isIndexPairEncoding; 1437 }; 1438 1439 /// Parse use-list order from bytecode for a range of values if available. The 1440 /// range is expected to be either a block argument or an op result range. On 1441 /// success, return a map of the position in the range and the use-list order 1442 /// encoding. The function assumes to know the size of the range it is 1443 /// processing. 1444 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; 1445 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, 1446 uint64_t rangeSize); 1447 1448 /// Shuffle the use-chain according to the order parsed. 1449 LogicalResult sortUseListOrder(Value value); 1450 1451 /// Recursively visit all the values defined within topLevelOp and sort the 1452 /// use-list orders according to the indices parsed. 1453 LogicalResult processUseLists(Operation *topLevelOp); 1454 1455 //===--------------------------------------------------------------------===// 1456 // Fields 1457 1458 /// This class represents a single value scope, in which a value scope is 1459 /// delimited by isolated from above regions. 1460 struct ValueScope { 1461 /// Push a new region state onto this scope, reserving enough values for 1462 /// those defined within the current region of the provided state. 1463 void push(RegionReadState &readState) { 1464 nextValueIDs.push_back(values.size()); 1465 values.resize(values.size() + readState.numValues); 1466 } 1467 1468 /// Pop the values defined for the current region within the provided region 1469 /// state. 1470 void pop(RegionReadState &readState) { 1471 values.resize(values.size() - readState.numValues); 1472 nextValueIDs.pop_back(); 1473 } 1474 1475 /// The set of values defined in this scope. 1476 std::vector<Value> values; 1477 1478 /// The ID for the next defined value for each region current being 1479 /// processed in this scope. 1480 SmallVector<unsigned, 4> nextValueIDs; 1481 }; 1482 1483 /// The configuration of the parser. 1484 const ParserConfig &config; 1485 1486 /// A location to use when emitting errors. 1487 Location fileLoc; 1488 1489 /// Flag that indicates if lazyloading is enabled. 1490 bool lazyLoading; 1491 1492 /// Keep track of operations that have been lazy loaded (their regions haven't 1493 /// been materialized), along with the `RegionReadState` that allows to 1494 /// lazy-load the regions nested under the operation. 1495 LazyLoadableOpsInfo lazyLoadableOps; 1496 LazyLoadableOpsMap lazyLoadableOpsMap; 1497 llvm::function_ref<bool(Operation *)> lazyOpsCallback; 1498 1499 /// The reader used to process attribute and types within the bytecode. 1500 AttrTypeReader attrTypeReader; 1501 1502 /// The version of the bytecode being read. 1503 uint64_t version = 0; 1504 1505 /// The producer of the bytecode being read. 1506 StringRef producer; 1507 1508 /// The table of IR units referenced within the bytecode file. 1509 SmallVector<BytecodeDialect> dialects; 1510 SmallVector<BytecodeOperationName> opNames; 1511 1512 /// The reader used to process resources within the bytecode. 1513 ResourceSectionReader resourceReader; 1514 1515 /// Worklist of values with custom use-list orders to process before the end 1516 /// of the parsing. 1517 DenseMap<void *, UseListOrderStorage> valueToUseListMap; 1518 1519 /// The table of strings referenced within the bytecode file. 1520 StringSectionReader stringReader; 1521 1522 /// The table of properties referenced by the operation in the bytecode file. 1523 PropertiesSectionReader propertiesReader; 1524 1525 /// The current set of available IR value scopes. 1526 std::vector<ValueScope> valueScopes; 1527 1528 /// The global pre-order operation ordering. 1529 DenseMap<Operation *, unsigned> operationIDs; 1530 1531 /// A block containing the set of operations defined to create forward 1532 /// references. 1533 Block forwardRefOps; 1534 1535 /// A block containing previously created, and no longer used, forward 1536 /// reference operations. 1537 Block openForwardRefOps; 1538 1539 /// An operation state used when instantiating forward references. 1540 OperationState forwardRefOpState; 1541 1542 /// Reference to the input buffer. 1543 llvm::MemoryBufferRef buffer; 1544 1545 /// The optional owning source manager, which when present may be used to 1546 /// extend the lifetime of the input buffer. 1547 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 1548 }; 1549 1550 LogicalResult BytecodeReader::Impl::read( 1551 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1552 EncodingReader reader(buffer.getBuffer(), fileLoc); 1553 this->lazyOpsCallback = lazyOpsCallback; 1554 auto resetlazyOpsCallback = 1555 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1556 1557 // Skip over the bytecode header, this should have already been checked. 1558 if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) 1559 return failure(); 1560 // Parse the bytecode version and producer. 1561 if (failed(parseVersion(reader)) || 1562 failed(reader.parseNullTerminatedString(producer))) 1563 return failure(); 1564 1565 // Add a diagnostic handler that attaches a note that includes the original 1566 // producer of the bytecode. 1567 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { 1568 diag.attachNote() << "in bytecode version " << version 1569 << " produced by: " << producer; 1570 return failure(); 1571 }); 1572 1573 // Parse the raw data for each of the top-level sections of the bytecode. 1574 std::optional<ArrayRef<uint8_t>> 1575 sectionDatas[bytecode::Section::kNumSections]; 1576 while (!reader.empty()) { 1577 // Read the next section from the bytecode. 1578 bytecode::Section::ID sectionID; 1579 ArrayRef<uint8_t> sectionData; 1580 if (failed(reader.parseSection(sectionID, sectionData))) 1581 return failure(); 1582 1583 // Check for duplicate sections, we only expect one instance of each. 1584 if (sectionDatas[sectionID]) { 1585 return reader.emitError("duplicate top-level section: ", 1586 ::toString(sectionID)); 1587 } 1588 sectionDatas[sectionID] = sectionData; 1589 } 1590 // Check that all of the required sections were found. 1591 for (int i = 0; i < bytecode::Section::kNumSections; ++i) { 1592 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); 1593 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { 1594 return reader.emitError("missing data for top-level section: ", 1595 ::toString(sectionID)); 1596 } 1597 } 1598 1599 // Process the string section first. 1600 if (failed(stringReader.initialize( 1601 fileLoc, *sectionDatas[bytecode::Section::kString]))) 1602 return failure(); 1603 1604 // Process the properties section. 1605 if (sectionDatas[bytecode::Section::kProperties] && 1606 failed(propertiesReader.initialize( 1607 fileLoc, *sectionDatas[bytecode::Section::kProperties]))) 1608 return failure(); 1609 1610 // Process the dialect section. 1611 if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) 1612 return failure(); 1613 1614 // Process the resource section if present. 1615 if (failed(parseResourceSection( 1616 reader, sectionDatas[bytecode::Section::kResource], 1617 sectionDatas[bytecode::Section::kResourceOffset]))) 1618 return failure(); 1619 1620 // Process the attribute and type section. 1621 if (failed(attrTypeReader.initialize( 1622 dialects, *sectionDatas[bytecode::Section::kAttrType], 1623 *sectionDatas[bytecode::Section::kAttrTypeOffset]))) 1624 return failure(); 1625 1626 // Finally, process the IR section. 1627 return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); 1628 } 1629 1630 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { 1631 if (failed(reader.parseVarInt(version))) 1632 return failure(); 1633 1634 // Validate the bytecode version. 1635 uint64_t currentVersion = bytecode::kVersion; 1636 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; 1637 if (version < minSupportedVersion) { 1638 return reader.emitError("bytecode version ", version, 1639 " is older than the current version of ", 1640 currentVersion, ", and upgrade is not supported"); 1641 } 1642 if (version > currentVersion) { 1643 return reader.emitError("bytecode version ", version, 1644 " is newer than the current version ", 1645 currentVersion); 1646 } 1647 // Override any request to lazy-load if the bytecode version is too old. 1648 if (version < 2) 1649 lazyLoading = false; 1650 return success(); 1651 } 1652 1653 //===----------------------------------------------------------------------===// 1654 // Dialect Section 1655 1656 LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { 1657 if (dialect) 1658 return success(); 1659 Dialect *loadedDialect = ctx->getOrLoadDialect(name); 1660 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { 1661 return reader.emitError("dialect '") 1662 << name 1663 << "' is unknown. If this is intended, please call " 1664 "allowUnregisteredDialects() on the MLIRContext, or use " 1665 "-allow-unregistered-dialect with the MLIR tool used."; 1666 } 1667 dialect = loadedDialect; 1668 1669 // If the dialect was actually loaded, check to see if it has a bytecode 1670 // interface. 1671 if (loadedDialect) 1672 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); 1673 if (!versionBuffer.empty()) { 1674 if (!interface) 1675 return reader.emitError("dialect '") 1676 << name 1677 << "' does not implement the bytecode interface, " 1678 "but found a version entry"; 1679 EncodingReader encReader(versionBuffer, reader.getLoc()); 1680 DialectReader versionReader = reader.withEncodingReader(encReader); 1681 loadedVersion = interface->readVersion(versionReader); 1682 if (!loadedVersion) 1683 return failure(); 1684 } 1685 return success(); 1686 } 1687 1688 LogicalResult 1689 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { 1690 EncodingReader sectionReader(sectionData, fileLoc); 1691 1692 // Parse the number of dialects in the section. 1693 uint64_t numDialects; 1694 if (failed(sectionReader.parseVarInt(numDialects))) 1695 return failure(); 1696 dialects.resize(numDialects); 1697 1698 // Parse each of the dialects. 1699 for (uint64_t i = 0; i < numDialects; ++i) { 1700 /// Before version 1, there wasn't any versioning available for dialects, 1701 /// and the entryIdx represent the string itself. 1702 if (version == 0) { 1703 if (failed(stringReader.parseString(sectionReader, dialects[i].name))) 1704 return failure(); 1705 continue; 1706 } 1707 // Parse ID representing dialect and version. 1708 uint64_t dialectNameIdx; 1709 bool versionAvailable; 1710 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, 1711 versionAvailable))) 1712 return failure(); 1713 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, 1714 dialects[i].name))) 1715 return failure(); 1716 if (versionAvailable) { 1717 bytecode::Section::ID sectionID; 1718 if (failed( 1719 sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) 1720 return failure(); 1721 if (sectionID != bytecode::Section::kDialectVersions) { 1722 emitError(fileLoc, "expected dialect version section"); 1723 return failure(); 1724 } 1725 } 1726 } 1727 1728 // Parse the operation names, which are grouped by dialect. 1729 auto parseOpName = [&](BytecodeDialect *dialect) { 1730 StringRef opName; 1731 bool wasRegistered; 1732 // Prior to version 4, the information about wheter an op was registered or 1733 // not wasn't encoded. 1734 if (version < 4) { 1735 if (failed(stringReader.parseString(sectionReader, opName))) 1736 return failure(); 1737 } else { 1738 if (failed(stringReader.parseStringWithFlag(sectionReader, opName, 1739 wasRegistered))) 1740 return failure(); 1741 } 1742 opNames.emplace_back(dialect, opName, wasRegistered); 1743 return success(); 1744 }; 1745 // Avoid re-allocation in bytecode version > 3 where the number of ops are 1746 // known. 1747 if (version > 3) { 1748 uint64_t numOps; 1749 if (failed(sectionReader.parseVarInt(numOps))) 1750 return failure(); 1751 opNames.reserve(numOps); 1752 } 1753 while (!sectionReader.empty()) 1754 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) 1755 return failure(); 1756 return success(); 1757 } 1758 1759 FailureOr<OperationName> 1760 BytecodeReader::Impl::parseOpName(EncodingReader &reader, bool &wasRegistered) { 1761 BytecodeOperationName *opName = nullptr; 1762 if (failed(parseEntry(reader, opNames, opName, "operation name"))) 1763 return failure(); 1764 wasRegistered = opName->wasRegistered; 1765 // Check to see if this operation name has already been resolved. If we 1766 // haven't, load the dialect and build the operation name. 1767 if (!opName->opName) { 1768 // Load the dialect and its version. 1769 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1770 reader); 1771 if (failed(opName->dialect->load(dialectReader, getContext()))) 1772 return failure(); 1773 // If the opName is empty, this is because we use to accept names such as 1774 // `foo` without any `.` separator. We shouldn't tolerate this in textual 1775 // format anymore but for now we'll be backward compatible. This can only 1776 // happen with unregistered dialects. 1777 if (opName->name.empty()) { 1778 if (opName->dialect->getLoadedDialect()) 1779 return emitError(fileLoc) << "has an empty opname for dialect '" 1780 << opName->dialect->name << "'\n"; 1781 1782 opName->opName.emplace(opName->dialect->name, getContext()); 1783 } else { 1784 opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), 1785 getContext()); 1786 } 1787 } 1788 return *opName->opName; 1789 } 1790 1791 //===----------------------------------------------------------------------===// 1792 // Resource Section 1793 1794 LogicalResult BytecodeReader::Impl::parseResourceSection( 1795 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, 1796 std::optional<ArrayRef<uint8_t>> resourceOffsetData) { 1797 // Ensure both sections are either present or not. 1798 if (resourceData.has_value() != resourceOffsetData.has_value()) { 1799 if (resourceOffsetData) 1800 return emitError(fileLoc, "unexpected resource offset section when " 1801 "resource section is not present"); 1802 return emitError( 1803 fileLoc, 1804 "expected resource offset section when resource section is present"); 1805 } 1806 1807 // If the resource sections are absent, there is nothing to do. 1808 if (!resourceData) 1809 return success(); 1810 1811 // Initialize the resource reader with the resource sections. 1812 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1813 reader); 1814 return resourceReader.initialize(fileLoc, config, dialects, stringReader, 1815 *resourceData, *resourceOffsetData, 1816 dialectReader, bufferOwnerRef); 1817 } 1818 1819 //===----------------------------------------------------------------------===// 1820 // UseListOrder Helpers 1821 1822 FailureOr<BytecodeReader::Impl::UseListMapT> 1823 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, 1824 uint64_t numResults) { 1825 BytecodeReader::Impl::UseListMapT map; 1826 uint64_t numValuesToRead = 1; 1827 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) 1828 return failure(); 1829 1830 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { 1831 uint64_t resultIdx = 0; 1832 if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) 1833 return failure(); 1834 1835 uint64_t numValues; 1836 bool indexPairEncoding; 1837 if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) 1838 return failure(); 1839 1840 SmallVector<unsigned, 4> useListOrders; 1841 for (size_t idx = 0; idx < numValues; idx++) { 1842 uint64_t index; 1843 if (failed(reader.parseVarInt(index))) 1844 return failure(); 1845 useListOrders.push_back(index); 1846 } 1847 1848 // Store in a map the result index 1849 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, 1850 std::move(useListOrders))); 1851 } 1852 1853 return map; 1854 } 1855 1856 /// Sorts each use according to the order specified in the use-list parsed. If 1857 /// the custom use-list is not found, this means that the order needs to be 1858 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie 1859 /// on the same operation, the order will follow the reverse operand number 1860 /// ordering. 1861 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { 1862 // Early return for trivial use-lists. 1863 if (value.use_empty() || value.hasOneUse()) 1864 return success(); 1865 1866 bool hasIncomingOrder = 1867 valueToUseListMap.contains(value.getAsOpaquePointer()); 1868 1869 // Compute the current order of the use-list with respect to the global 1870 // ordering. Detect if the order is already sorted while doing so. 1871 bool alreadySorted = true; 1872 auto &firstUse = *value.use_begin(); 1873 uint64_t prevID = 1874 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); 1875 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; 1876 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1877 uint64_t currentID = bytecode::getUseID( 1878 item.value(), operationIDs.at(item.value().getOwner())); 1879 alreadySorted &= prevID > currentID; 1880 currentOrder.push_back({item.index(), currentID}); 1881 prevID = currentID; 1882 } 1883 1884 // If the order is already sorted, and there wasn't a custom order to apply 1885 // from the bytecode file, we are done. 1886 if (alreadySorted && !hasIncomingOrder) 1887 return success(); 1888 1889 // If not already sorted, sort the indices of the current order by descending 1890 // useIDs. 1891 if (!alreadySorted) 1892 std::sort( 1893 currentOrder.begin(), currentOrder.end(), 1894 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1895 1896 if (!hasIncomingOrder) { 1897 // If the bytecode file did not contain any custom use-list order, it means 1898 // that the order was descending useID. Hence, shuffle by the first index 1899 // of the `currentOrder` pair. 1900 SmallVector<unsigned> shuffle = SmallVector<unsigned>( 1901 llvm::map_range(currentOrder, [&](auto item) { return item.first; })); 1902 value.shuffleUseList(shuffle); 1903 return success(); 1904 } 1905 1906 // Pull the custom order info from the map. 1907 UseListOrderStorage customOrder = 1908 valueToUseListMap.at(value.getAsOpaquePointer()); 1909 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); 1910 uint64_t numUses = 1911 std::distance(value.getUses().begin(), value.getUses().end()); 1912 1913 // If the encoding was a pair of indices `(src, dst)` for every permutation, 1914 // reconstruct the shuffle vector for every use. Initialize the shuffle vector 1915 // as identity, and then apply the mapping encoded in the indices. 1916 if (customOrder.isIndexPairEncoding) { 1917 // Return failure if the number of indices was not representing pairs. 1918 if (shuffle.size() & 1) 1919 return failure(); 1920 1921 SmallVector<unsigned, 4> newShuffle(numUses); 1922 size_t idx = 0; 1923 std::iota(newShuffle.begin(), newShuffle.end(), idx); 1924 for (idx = 0; idx < shuffle.size(); idx += 2) 1925 newShuffle[shuffle[idx]] = shuffle[idx + 1]; 1926 1927 shuffle = std::move(newShuffle); 1928 } 1929 1930 // Make sure that the indices represent a valid mapping. That is, the sum of 1931 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no 1932 // duplicates are allowed in the list. 1933 DenseSet<unsigned> set; 1934 uint64_t accumulator = 0; 1935 for (const auto &elem : shuffle) { 1936 if (set.contains(elem)) 1937 return failure(); 1938 accumulator += elem; 1939 set.insert(elem); 1940 } 1941 if (numUses != shuffle.size() || 1942 accumulator != (((numUses - 1) * numUses) >> 1)) 1943 return failure(); 1944 1945 // Apply the current ordering map onto the shuffle vector to get the final 1946 // use-list sorting indices before shuffling. 1947 shuffle = SmallVector<unsigned, 4>(llvm::map_range( 1948 currentOrder, [&](auto item) { return shuffle[item.first]; })); 1949 value.shuffleUseList(shuffle); 1950 return success(); 1951 } 1952 1953 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { 1954 // Precompute operation IDs according to the pre-order walk of the IR. We 1955 // can't do this while parsing since parseRegions ordering is not strictly 1956 // equal to the pre-order walk. 1957 unsigned operationID = 0; 1958 topLevelOp->walk<mlir::WalkOrder::PreOrder>( 1959 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); 1960 1961 auto blockWalk = topLevelOp->walk([this](Block *block) { 1962 for (auto arg : block->getArguments()) 1963 if (failed(sortUseListOrder(arg))) 1964 return WalkResult::interrupt(); 1965 return WalkResult::advance(); 1966 }); 1967 1968 auto resultWalk = topLevelOp->walk([this](Operation *op) { 1969 for (auto result : op->getResults()) 1970 if (failed(sortUseListOrder(result))) 1971 return WalkResult::interrupt(); 1972 return WalkResult::advance(); 1973 }); 1974 1975 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); 1976 } 1977 1978 //===----------------------------------------------------------------------===// 1979 // IR Section 1980 1981 LogicalResult 1982 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, 1983 Block *block) { 1984 EncodingReader reader(sectionData, fileLoc); 1985 1986 // A stack of operation regions currently being read from the bytecode. 1987 std::vector<RegionReadState> regionStack; 1988 1989 // Parse the top-level block using a temporary module operation. 1990 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); 1991 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); 1992 regionStack.back().curBlocks.push_back(moduleOp->getBody()); 1993 regionStack.back().curBlock = regionStack.back().curRegion->begin(); 1994 if (failed(parseBlockHeader(reader, regionStack.back()))) 1995 return failure(); 1996 valueScopes.emplace_back(); 1997 valueScopes.back().push(regionStack.back()); 1998 1999 // Iteratively parse regions until everything has been resolved. 2000 while (!regionStack.empty()) 2001 if (failed(parseRegions(regionStack, regionStack.back()))) 2002 return failure(); 2003 if (!forwardRefOps.empty()) { 2004 return reader.emitError( 2005 "not all forward unresolved forward operand references"); 2006 } 2007 2008 // Sort use-lists according to what specified in bytecode. 2009 if (failed(processUseLists(*moduleOp))) 2010 return reader.emitError( 2011 "parsed use-list orders were invalid and could not be applied"); 2012 2013 // Resolve dialect version. 2014 for (const BytecodeDialect &byteCodeDialect : dialects) { 2015 // Parsing is complete, give an opportunity to each dialect to visit the 2016 // IR and perform upgrades. 2017 if (!byteCodeDialect.loadedVersion) 2018 continue; 2019 if (byteCodeDialect.interface && 2020 failed(byteCodeDialect.interface->upgradeFromVersion( 2021 *moduleOp, *byteCodeDialect.loadedVersion))) 2022 return failure(); 2023 } 2024 2025 // Verify that the parsed operations are valid. 2026 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) 2027 return failure(); 2028 2029 // Splice the parsed operations over to the provided top-level block. 2030 auto &parsedOps = moduleOp->getBody()->getOperations(); 2031 auto &destOps = block->getOperations(); 2032 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); 2033 return success(); 2034 } 2035 2036 LogicalResult 2037 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, 2038 RegionReadState &readState) { 2039 // Process regions, blocks, and operations until the end or if a nested 2040 // region is encountered. In this case we push a new state in regionStack and 2041 // return, the processing of the current region will resume afterward. 2042 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { 2043 // If the current block hasn't been setup yet, parse the header for this 2044 // region. The current block is already setup when this function was 2045 // interrupted to recurse down in a nested region and we resume the current 2046 // block after processing the nested region. 2047 if (readState.curBlock == Region::iterator()) { 2048 if (failed(parseRegion(readState))) 2049 return failure(); 2050 2051 // If the region is empty, there is nothing to more to do. 2052 if (readState.curRegion->empty()) 2053 continue; 2054 } 2055 2056 // Parse the blocks within the region. 2057 EncodingReader &reader = *readState.reader; 2058 do { 2059 while (readState.numOpsRemaining--) { 2060 // Read in the next operation. We don't read its regions directly, we 2061 // handle those afterwards as necessary. 2062 bool isIsolatedFromAbove = false; 2063 FailureOr<Operation *> op = 2064 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); 2065 if (failed(op)) 2066 return failure(); 2067 2068 // If the op has regions, add it to the stack for processing and return: 2069 // we stop the processing of the current region and resume it after the 2070 // inner one is completed. Unless LazyLoading is activated in which case 2071 // nested region parsing is delayed. 2072 if ((*op)->getNumRegions()) { 2073 RegionReadState childState(*op, &reader, isIsolatedFromAbove); 2074 2075 // Isolated regions are encoded as a section in version 2 and above. 2076 if (version >= 2 && isIsolatedFromAbove) { 2077 bytecode::Section::ID sectionID; 2078 ArrayRef<uint8_t> sectionData; 2079 if (failed(reader.parseSection(sectionID, sectionData))) 2080 return failure(); 2081 if (sectionID != bytecode::Section::kIR) 2082 return emitError(fileLoc, "expected IR section for region"); 2083 childState.owningReader = 2084 std::make_unique<EncodingReader>(sectionData, fileLoc); 2085 childState.reader = childState.owningReader.get(); 2086 } 2087 2088 if (lazyLoading) { 2089 // If the user has a callback set, they have the opportunity 2090 // to control lazyloading as we go. 2091 if (!lazyOpsCallback || !lazyOpsCallback(*op)) { 2092 lazyLoadableOps.push_back( 2093 std::make_pair(*op, std::move(childState))); 2094 lazyLoadableOpsMap.try_emplace(*op, 2095 std::prev(lazyLoadableOps.end())); 2096 continue; 2097 } 2098 } 2099 regionStack.push_back(std::move(childState)); 2100 2101 // If the op is isolated from above, push a new value scope. 2102 if (isIsolatedFromAbove) 2103 valueScopes.emplace_back(); 2104 return success(); 2105 } 2106 } 2107 2108 // Move to the next block of the region. 2109 if (++readState.curBlock == readState.curRegion->end()) 2110 break; 2111 if (failed(parseBlockHeader(reader, readState))) 2112 return failure(); 2113 } while (true); 2114 2115 // Reset the current block and any values reserved for this region. 2116 readState.curBlock = {}; 2117 valueScopes.back().pop(readState); 2118 } 2119 2120 // When the regions have been fully parsed, pop them off of the read stack. If 2121 // the regions were isolated from above, we also pop the last value scope. 2122 if (readState.isIsolatedFromAbove) { 2123 assert(!valueScopes.empty() && "Expect a valueScope after reading region"); 2124 valueScopes.pop_back(); 2125 } 2126 assert(!regionStack.empty() && "Expect a regionStack after reading region"); 2127 regionStack.pop_back(); 2128 return success(); 2129 } 2130 2131 FailureOr<Operation *> 2132 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, 2133 RegionReadState &readState, 2134 bool &isIsolatedFromAbove) { 2135 // Parse the name of the operation. 2136 bool wasRegistered; 2137 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); 2138 if (failed(opName)) 2139 return failure(); 2140 2141 // Parse the operation mask, which indicates which components of the operation 2142 // are present. 2143 uint8_t opMask; 2144 if (failed(reader.parseByte(opMask))) 2145 return failure(); 2146 2147 /// Parse the location. 2148 LocationAttr opLoc; 2149 if (failed(parseAttribute(reader, opLoc))) 2150 return failure(); 2151 2152 // With the location and name resolved, we can start building the operation 2153 // state. 2154 OperationState opState(opLoc, *opName); 2155 2156 // Parse the attributes of the operation. 2157 if (opMask & bytecode::OpEncodingMask::kHasAttrs) { 2158 DictionaryAttr dictAttr; 2159 if (failed(parseAttribute(reader, dictAttr))) 2160 return failure(); 2161 opState.attributes = dictAttr; 2162 } 2163 2164 if (opMask & bytecode::OpEncodingMask::kHasProperties) { 2165 if (wasRegistered) { 2166 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 2167 reader); 2168 if (failed( 2169 propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) 2170 return failure(); 2171 } else { 2172 // If the operation wasn't registered when it was emitted, the properties 2173 // was serialized as an attribute. 2174 if (failed(parseAttribute(reader, opState.propertiesAttr))) 2175 return failure(); 2176 } 2177 } 2178 2179 /// Parse the results of the operation. 2180 if (opMask & bytecode::OpEncodingMask::kHasResults) { 2181 uint64_t numResults; 2182 if (failed(reader.parseVarInt(numResults))) 2183 return failure(); 2184 opState.types.resize(numResults); 2185 for (int i = 0, e = numResults; i < e; ++i) 2186 if (failed(parseType(reader, opState.types[i]))) 2187 return failure(); 2188 } 2189 2190 /// Parse the operands of the operation. 2191 if (opMask & bytecode::OpEncodingMask::kHasOperands) { 2192 uint64_t numOperands; 2193 if (failed(reader.parseVarInt(numOperands))) 2194 return failure(); 2195 opState.operands.resize(numOperands); 2196 for (int i = 0, e = numOperands; i < e; ++i) 2197 if (!(opState.operands[i] = parseOperand(reader))) 2198 return failure(); 2199 } 2200 2201 /// Parse the successors of the operation. 2202 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { 2203 uint64_t numSuccs; 2204 if (failed(reader.parseVarInt(numSuccs))) 2205 return failure(); 2206 opState.successors.resize(numSuccs); 2207 for (int i = 0, e = numSuccs; i < e; ++i) { 2208 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], 2209 "successor"))) 2210 return failure(); 2211 } 2212 } 2213 2214 /// Parse the use-list orders for the results of the operation. Use-list 2215 /// orders are available since version 3 of the bytecode. 2216 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; 2217 if (version > 2 && (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { 2218 size_t numResults = opState.types.size(); 2219 auto parseResult = parseUseListOrderForRange(reader, numResults); 2220 if (failed(parseResult)) 2221 return failure(); 2222 resultIdxToUseListMap = std::move(*parseResult); 2223 } 2224 2225 /// Parse the regions of the operation. 2226 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { 2227 uint64_t numRegions; 2228 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) 2229 return failure(); 2230 2231 opState.regions.reserve(numRegions); 2232 for (int i = 0, e = numRegions; i < e; ++i) 2233 opState.regions.push_back(std::make_unique<Region>()); 2234 } 2235 2236 // Create the operation at the back of the current block. 2237 Operation *op = Operation::create(opState); 2238 readState.curBlock->push_back(op); 2239 2240 // If the operation had results, update the value references. 2241 if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) 2242 return failure(); 2243 2244 /// Store a map for every value that received a custom use-list order from the 2245 /// bytecode file. 2246 if (resultIdxToUseListMap.has_value()) { 2247 for (size_t idx = 0; idx < op->getNumResults(); idx++) { 2248 if (resultIdxToUseListMap->contains(idx)) { 2249 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), 2250 resultIdxToUseListMap->at(idx)); 2251 } 2252 } 2253 } 2254 return op; 2255 } 2256 2257 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { 2258 EncodingReader &reader = *readState.reader; 2259 2260 // Parse the number of blocks in the region. 2261 uint64_t numBlocks; 2262 if (failed(reader.parseVarInt(numBlocks))) 2263 return failure(); 2264 2265 // If the region is empty, there is nothing else to do. 2266 if (numBlocks == 0) 2267 return success(); 2268 2269 // Parse the number of values defined in this region. 2270 uint64_t numValues; 2271 if (failed(reader.parseVarInt(numValues))) 2272 return failure(); 2273 readState.numValues = numValues; 2274 2275 // Create the blocks within this region. We do this before processing so that 2276 // we can rely on the blocks existing when creating operations. 2277 readState.curBlocks.clear(); 2278 readState.curBlocks.reserve(numBlocks); 2279 for (uint64_t i = 0; i < numBlocks; ++i) { 2280 readState.curBlocks.push_back(new Block()); 2281 readState.curRegion->push_back(readState.curBlocks.back()); 2282 } 2283 2284 // Prepare the current value scope for this region. 2285 valueScopes.back().push(readState); 2286 2287 // Parse the entry block of the region. 2288 readState.curBlock = readState.curRegion->begin(); 2289 return parseBlockHeader(reader, readState); 2290 } 2291 2292 LogicalResult 2293 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, 2294 RegionReadState &readState) { 2295 bool hasArgs; 2296 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) 2297 return failure(); 2298 2299 // Parse the arguments of the block. 2300 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) 2301 return failure(); 2302 2303 // Uselist orders are available since version 3 of the bytecode. 2304 if (version < 3) 2305 return success(); 2306 2307 uint8_t hasUseListOrders = 0; 2308 if (hasArgs && failed(reader.parseByte(hasUseListOrders))) 2309 return failure(); 2310 2311 if (!hasUseListOrders) 2312 return success(); 2313 2314 Block &blk = *readState.curBlock; 2315 auto argIdxToUseListMap = 2316 parseUseListOrderForRange(reader, blk.getNumArguments()); 2317 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) 2318 return failure(); 2319 2320 for (size_t idx = 0; idx < blk.getNumArguments(); idx++) 2321 if (argIdxToUseListMap->contains(idx)) 2322 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), 2323 argIdxToUseListMap->at(idx)); 2324 2325 // We don't parse the operations of the block here, that's done elsewhere. 2326 return success(); 2327 } 2328 2329 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, 2330 Block *block) { 2331 // Parse the value ID for the first argument, and the number of arguments. 2332 uint64_t numArgs; 2333 if (failed(reader.parseVarInt(numArgs))) 2334 return failure(); 2335 2336 SmallVector<Type> argTypes; 2337 SmallVector<Location> argLocs; 2338 argTypes.reserve(numArgs); 2339 argLocs.reserve(numArgs); 2340 2341 Location unknownLoc = UnknownLoc::get(config.getContext()); 2342 while (numArgs--) { 2343 Type argType; 2344 LocationAttr argLoc = unknownLoc; 2345 if (version > 3) { 2346 // Parse the type with hasLoc flag to determine if it has type. 2347 uint64_t typeIdx; 2348 bool hasLoc; 2349 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || 2350 !(argType = attrTypeReader.resolveType(typeIdx))) 2351 return failure(); 2352 if (hasLoc && failed(parseAttribute(reader, argLoc))) 2353 return failure(); 2354 } else { 2355 // All args has type and location. 2356 if (failed(parseType(reader, argType)) || 2357 failed(parseAttribute(reader, argLoc))) 2358 return failure(); 2359 } 2360 argTypes.push_back(argType); 2361 argLocs.push_back(argLoc); 2362 } 2363 block->addArguments(argTypes, argLocs); 2364 return defineValues(reader, block->getArguments()); 2365 } 2366 2367 //===----------------------------------------------------------------------===// 2368 // Value Processing 2369 2370 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { 2371 std::vector<Value> &values = valueScopes.back().values; 2372 Value *value = nullptr; 2373 if (failed(parseEntry(reader, values, value, "value"))) 2374 return Value(); 2375 2376 // Create a new forward reference if necessary. 2377 if (!*value) 2378 *value = createForwardRef(); 2379 return *value; 2380 } 2381 2382 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, 2383 ValueRange newValues) { 2384 ValueScope &valueScope = valueScopes.back(); 2385 std::vector<Value> &values = valueScope.values; 2386 2387 unsigned &valueID = valueScope.nextValueIDs.back(); 2388 unsigned valueIDEnd = valueID + newValues.size(); 2389 if (valueIDEnd > values.size()) { 2390 return reader.emitError( 2391 "value index range was outside of the expected range for " 2392 "the parent region, got [", 2393 valueID, ", ", valueIDEnd, "), but the maximum index was ", 2394 values.size() - 1); 2395 } 2396 2397 // Assign the values and update any forward references. 2398 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { 2399 Value newValue = newValues[i]; 2400 2401 // Check to see if a definition for this value already exists. 2402 if (Value oldValue = std::exchange(values[valueID], newValue)) { 2403 Operation *forwardRefOp = oldValue.getDefiningOp(); 2404 2405 // Assert that this is a forward reference operation. Given how we compute 2406 // definition ids (incrementally as we parse), it shouldn't be possible 2407 // for the value to be defined any other way. 2408 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && 2409 "value index was already defined?"); 2410 2411 oldValue.replaceAllUsesWith(newValue); 2412 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); 2413 } 2414 } 2415 return success(); 2416 } 2417 2418 Value BytecodeReader::Impl::createForwardRef() { 2419 // Check for an avaliable existing operation to use. Otherwise, create a new 2420 // fake operation to use for the reference. 2421 if (!openForwardRefOps.empty()) { 2422 Operation *op = &openForwardRefOps.back(); 2423 op->moveBefore(&forwardRefOps, forwardRefOps.end()); 2424 } else { 2425 forwardRefOps.push_back(Operation::create(forwardRefOpState)); 2426 } 2427 return forwardRefOps.back().getResult(0); 2428 } 2429 2430 //===----------------------------------------------------------------------===// 2431 // Entry Points 2432 //===----------------------------------------------------------------------===// 2433 2434 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } 2435 2436 BytecodeReader::BytecodeReader( 2437 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, 2438 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2439 Location sourceFileLoc = 2440 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2441 /*line=*/0, /*column=*/0); 2442 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, 2443 bufferOwnerRef); 2444 } 2445 2446 LogicalResult BytecodeReader::readTopLevel( 2447 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2448 return impl->read(block, lazyOpsCallback); 2449 } 2450 2451 int64_t BytecodeReader::getNumOpsToMaterialize() const { 2452 return impl->getNumOpsToMaterialize(); 2453 } 2454 2455 bool BytecodeReader::isMaterializable(Operation *op) { 2456 return impl->isMaterializable(op); 2457 } 2458 2459 LogicalResult BytecodeReader::materialize( 2460 Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2461 return impl->materialize(op, lazyOpsCallback); 2462 } 2463 2464 LogicalResult 2465 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { 2466 return impl->finalize(shouldMaterialize); 2467 } 2468 2469 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { 2470 return buffer.getBuffer().startswith("ML\xefR"); 2471 } 2472 2473 /// Read the bytecode from the provided memory buffer reference. 2474 /// `bufferOwnerRef` if provided is the owning source manager for the buffer, 2475 /// and may be used to extend the lifetime of the buffer. 2476 static LogicalResult 2477 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, 2478 const ParserConfig &config, 2479 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2480 Location sourceFileLoc = 2481 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2482 /*line=*/0, /*column=*/0); 2483 if (!isBytecode(buffer)) { 2484 return emitError(sourceFileLoc, 2485 "input buffer is not an MLIR bytecode file"); 2486 } 2487 2488 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, 2489 buffer, bufferOwnerRef); 2490 return reader.read(block, /*lazyOpsCallback=*/nullptr); 2491 } 2492 2493 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, 2494 const ParserConfig &config) { 2495 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); 2496 } 2497 LogicalResult 2498 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 2499 Block *block, const ParserConfig &config) { 2500 return readBytecodeFileImpl( 2501 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, 2502 sourceMgr); 2503 } 2504