1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Bytecode/BytecodeReader.h" 10 #include "mlir/AsmParser/AsmParser.h" 11 #include "mlir/Bytecode/BytecodeImplementation.h" 12 #include "mlir/Bytecode/BytecodeOpInterface.h" 13 #include "mlir/Bytecode/Encoding.h" 14 #include "mlir/IR/BuiltinDialect.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/Diagnostics.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/Verifier.h" 19 #include "mlir/IR/Visitors.h" 20 #include "mlir/Support/LLVM.h" 21 #include "mlir/Support/LogicalResult.h" 22 #include "llvm/ADT/ArrayRef.h" 23 #include "llvm/ADT/MapVector.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/ADT/StringRef.h" 28 #include "llvm/Support/Endian.h" 29 #include "llvm/Support/MemoryBufferRef.h" 30 #include "llvm/Support/SaveAndRestore.h" 31 #include "llvm/Support/SourceMgr.h" 32 #include <cstddef> 33 #include <list> 34 #include <memory> 35 #include <numeric> 36 #include <optional> 37 38 #define DEBUG_TYPE "mlir-bytecode-reader" 39 40 using namespace mlir; 41 42 /// Stringify the given section ID. 
static std::string toString(bytecode::Section::ID sectionID) {
  // Known sections are rendered as "<Name> (<numeric id>)"; anything else is
  // rendered as "Unknown (<numeric id>)".
  switch (sectionID) {
  case bytecode::Section::kString:
    return "String (0)";
  case bytecode::Section::kDialect:
    return "Dialect (1)";
  case bytecode::Section::kAttrType:
    return "AttrType (2)";
  case bytecode::Section::kAttrTypeOffset:
    return "AttrTypeOffset (3)";
  case bytecode::Section::kIR:
    return "IR (4)";
  case bytecode::Section::kResource:
    return "Resource (5)";
  case bytecode::Section::kResourceOffset:
    return "ResourceOffset (6)";
  case bytecode::Section::kDialectVersions:
    return "DialectVersions (7)";
  case bytecode::Section::kProperties:
    return "Properties (8)";
  default:
    return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
  }
}

/// Returns true if the given top-level section ID is optional.
static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
  switch (sectionID) {
  case bytecode::Section::kString:
  case bytecode::Section::kDialect:
  case bytecode::Section::kAttrType:
  case bytecode::Section::kAttrTypeOffset:
  case bytecode::Section::kIR:
    return false;
  case bytecode::Section::kResource:
  case bytecode::Section::kResourceOffset:
  case bytecode::Section::kDialectVersions:
    return true;
  case bytecode::Section::kProperties:
    // The properties section is only optional for bytecode versions that
    // predate `kNativePropertiesEncoding`.
    return version < bytecode::kNativePropertiesEncoding;
  default:
    llvm_unreachable("unknown section ID");
  }
}

//===----------------------------------------------------------------------===//
// EncodingReader
//===----------------------------------------------------------------------===//

namespace {
/// A forward-only cursor over the byte range `[dataIt, dataEnd)`. Every parse
/// method either consumes bytes and returns success, or emits a diagnostic at
/// `fileLoc` and returns failure.
class EncodingReader {
public:
  explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
      : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {}
  explicit EncodingReader(StringRef contents, Location fileLoc)
      : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
                        contents.size()},
                       fileLoc) {}

  /// Returns true if the entire section has been read.
  bool empty() const { return dataIt == dataEnd; }

  /// Returns the remaining size of the bytecode.
  size_t size() const { return dataEnd - dataIt; }

  /// Align the current reader position to the specified alignment. Every
  /// padding byte that is skipped must be `bytecode::kAlignmentByte`.
  // NOTE(review): callers in this file pass a `uint64_t` alignment, which is
  // implicitly narrowed to `unsigned` here -- confirm that is intended.
  LogicalResult alignTo(unsigned alignment) {
    if (!llvm::isPowerOf2_32(alignment))
      return emitError("expected alignment to be a power-of-two");

    // Shift the reader position to the next alignment boundary.
    while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) {
      uint8_t padding;
      if (failed(parseByte(padding)))
        return failure();
      if (padding != bytecode::kAlignmentByte) {
        return emitError("expected alignment byte (0xCB), but got: '0x" +
                         llvm::utohexstr(padding) + "'");
      }
    }

    // Ensure the data iterator is now aligned. This case is unlikely because we
    // *just* went through the effort to align the data iterator.
    if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) {
      return emitError("expected data iterator aligned to ", alignment,
                       ", but got pointer: '0x" +
                           llvm::utohexstr((uintptr_t)dataIt) + "'");
    }

    return success();
  }

  /// Emit an error using the given arguments.
  template <typename... Args>
  InFlightDiagnostic emitError(Args &&...args) const {
    return ::emitError(fileLoc).append(std::forward<Args>(args)...);
  }
  InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }

  /// Parse a single byte from the stream. The byte is converted to `T` with a
  /// `static_cast`.
  template <typename T>
  LogicalResult parseByte(T &value) {
    if (empty())
      return emitError("attempting to parse a byte at the end of the bytecode");
    value = static_cast<T>(*dataIt++);
    return success();
  }
  /// Parse a range of bytes of 'length' into the given result.
LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
    if (length > size()) {
      return emitError("attempting to parse ", length, " bytes when only ",
                       size(), " remain");
    }
    // `result` is a view into the reader's buffer; no bytes are copied.
    result = {dataIt, length};
    dataIt += length;
    return success();
  }
  /// Parse a range of bytes of 'length' into the given result, which can be
  /// assumed to be large enough to hold `length`.
  LogicalResult parseBytes(size_t length, uint8_t *result) {
    if (length > size()) {
      return emitError("attempting to parse ", length, " bytes when only ",
                       size(), " remain");
    }
    memcpy(result, dataIt, length);
    dataIt += length;
    return success();
  }

  /// Parse an aligned blob of data, where the alignment was encoded alongside
  /// the data.
  LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
                                      uint64_t &alignment) {
    // Layout: varint alignment, varint size, padding up to `alignment`, then
    // the blob bytes themselves.
    uint64_t dataSize;
    if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
        failed(alignTo(alignment)))
      return failure();
    return parseBytes(dataSize, data);
  }

  /// Parse a variable length encoded integer from the byte stream. The first
  /// encoded byte contains a prefix in the low bits indicating the encoded
  /// length of the value. This length prefix is a bit sequence of '0's followed
  /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
  /// (not including the prefix byte). All remaining bits in the first byte,
  /// along with all of the bits in additional bytes, provide the value of the
  /// integer encoded in little-endian order.
  LogicalResult parseVarInt(uint64_t &result) {
    // Parse the first byte of the encoding, which contains the length prefix.
    if (failed(parseByte(result)))
      return failure();

    // Handle the overwhelmingly common case where the value is stored in a
    // single byte. In this case, the first bit is the `1` marker bit.
    if (LLVM_LIKELY(result & 1)) {
      result >>= 1;
      return success();
    }

    // Handle the overwhelmingly uncommon case where the value required all 8
    // bytes (i.e. a really really big number). In this case, the marker byte is
    // all zeros: `00000000`.
    if (LLVM_UNLIKELY(result == 0)) {
      llvm::support::ulittle64_t resultLE;
      if (failed(parseBytes(sizeof(resultLE),
                            reinterpret_cast<uint8_t *>(&resultLE))))
        return failure();
      result = resultLE;
      return success();
    }
    return parseMultiByteVarInt(result);
  }

  /// Parse a signed variable length encoded integer from the byte stream. A
  /// signed varint is encoded as a normal varint with zigzag encoding applied,
  /// i.e. the low bit of the value is used to indicate the sign.
  LogicalResult parseSignedVarInt(uint64_t &result) {
    if (failed(parseVarInt(result)))
      return failure();
    // Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
    result = (result >> 1) ^ (~(result & 1) + 1);
    return success();
  }

  /// Parse a variable length encoded integer whose low bit is used to encode an
  /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
  LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
    if (failed(parseVarInt(result)))
      return failure();
    flag = result & 1;
    result >>= 1;
    return success();
  }

  /// Skip the first `length` bytes within the reader.
  LogicalResult skipBytes(size_t length) {
    if (length > size()) {
      return emitError("attempting to skip ", length, " bytes when only ",
                       size(), " remain");
    }
    dataIt += length;
    return success();
  }

  /// Parse a null-terminated string into `result` (without including the NUL
  /// terminator).
249 LogicalResult parseNullTerminatedString(StringRef &result) { 250 const char *startIt = (const char *)dataIt; 251 const char *nulIt = (const char *)memchr(startIt, 0, size()); 252 if (!nulIt) 253 return emitError( 254 "malformed null-terminated string, no null character found"); 255 256 result = StringRef(startIt, nulIt - startIt); 257 dataIt = (const uint8_t *)nulIt + 1; 258 return success(); 259 } 260 261 /// Parse a section header, placing the kind of section in `sectionID` and the 262 /// contents of the section in `sectionData`. 263 LogicalResult parseSection(bytecode::Section::ID §ionID, 264 ArrayRef<uint8_t> §ionData) { 265 uint8_t sectionIDAndHasAlignment; 266 uint64_t length; 267 if (failed(parseByte(sectionIDAndHasAlignment)) || 268 failed(parseVarInt(length))) 269 return failure(); 270 271 // Extract the section ID and whether the section is aligned. The high bit 272 // of the ID is the alignment flag. 273 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & 274 0b01111111); 275 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; 276 277 // Check that the section is actually valid before trying to process its 278 // data. 279 if (sectionID >= bytecode::Section::kNumSections) 280 return emitError("invalid section ID: ", unsigned(sectionID)); 281 282 // Process the section alignment if present. 283 if (hasAlignment) { 284 uint64_t alignment; 285 if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) 286 return failure(); 287 } 288 289 // Parse the actual section data. 290 return parseBytes(static_cast<size_t>(length), sectionData); 291 } 292 293 Location getLoc() const { return fileLoc; } 294 295 private: 296 /// Parse a variable length encoded integer from the byte stream. This method 297 /// is a fallback when the number of bytes used to encode the value is greater 298 /// than 1, but less than the max (9). The provided `result` value can be 299 /// assumed to already contain the first byte of the value. 
300 /// NOTE: This method is marked noinline to avoid pessimizing the common case 301 /// of single byte encoding. 302 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) { 303 // Count the number of trailing zeros in the marker byte, this indicates the 304 // number of trailing bytes that are part of the value. We use `uint32_t` 305 // here because we only care about the first byte, and so that be actually 306 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop 307 // implementation). 308 uint32_t numBytes = llvm::countr_zero<uint32_t>(result); 309 assert(numBytes > 0 && numBytes <= 7 && 310 "unexpected number of trailing zeros in varint encoding"); 311 312 // Parse in the remaining bytes of the value. 313 llvm::support::ulittle64_t resultLE(result); 314 if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1))) 315 return failure(); 316 317 // Shift out the low-order bits that were used to mark how the value was 318 // encoded. 319 result = resultLE >> (numBytes + 1); 320 return success(); 321 } 322 323 /// The current data iterator, and an iterator to the end of the buffer. 324 const uint8_t *dataIt, *dataEnd; 325 326 /// A location for the bytecode used to report errors. 327 Location fileLoc; 328 }; 329 } // namespace 330 331 /// Resolve an index into the given entry list. `entry` may either be a 332 /// reference, in which case it is assigned to the corresponding value in 333 /// `entries`, or a pointer, in which case it is assigned to the address of the 334 /// element in `entries`. 335 template <typename RangeT, typename T> 336 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, 337 uint64_t index, T &entry, 338 StringRef entryStr) { 339 if (index >= entries.size()) 340 return reader.emitError("invalid ", entryStr, " index: ", index); 341 342 // If the provided entry is a pointer, resolve to the address of the entry. 
343 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) 344 entry = entries[index]; 345 else 346 entry = &entries[index]; 347 return success(); 348 } 349 350 /// Parse and resolve an index into the given entry list. 351 template <typename RangeT, typename T> 352 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, 353 T &entry, StringRef entryStr) { 354 uint64_t entryIdx; 355 if (failed(reader.parseVarInt(entryIdx))) 356 return failure(); 357 return resolveEntry(reader, entries, entryIdx, entry, entryStr); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // StringSectionReader 362 //===----------------------------------------------------------------------===// 363 364 namespace { 365 /// This class is used to read references to the string section from the 366 /// bytecode. 367 class StringSectionReader { 368 public: 369 /// Initialize the string section reader with the given section data. 370 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); 371 372 /// Parse a shared string from the string section. The shared string is 373 /// encoded using an index to a corresponding string in the string section. 374 LogicalResult parseString(EncodingReader &reader, StringRef &result) { 375 return parseEntry(reader, strings, result, "string"); 376 } 377 378 /// Parse a shared string from the string section. The shared string is 379 /// encoded using an index to a corresponding string in the string section. 380 /// This variant parses a flag compressed with the index. 381 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, 382 bool &flag) { 383 uint64_t entryIdx; 384 if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) 385 return failure(); 386 return parseStringAtIndex(reader, entryIdx, result); 387 } 388 389 /// Parse a shared string from the string section. 
The shared string is 390 /// encoded using an index to a corresponding string in the string section. 391 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, 392 StringRef &result) { 393 return resolveEntry(reader, strings, index, result, "string"); 394 } 395 396 private: 397 /// The table of strings referenced within the bytecode file. 398 SmallVector<StringRef> strings; 399 }; 400 } // namespace 401 402 LogicalResult StringSectionReader::initialize(Location fileLoc, 403 ArrayRef<uint8_t> sectionData) { 404 EncodingReader stringReader(sectionData, fileLoc); 405 406 // Parse the number of strings in the section. 407 uint64_t numStrings; 408 if (failed(stringReader.parseVarInt(numStrings))) 409 return failure(); 410 strings.resize(numStrings); 411 412 // Parse each of the strings. The sizes of the strings are encoded in reverse 413 // order, so that's the order we populate the table. 414 size_t stringDataEndOffset = sectionData.size(); 415 for (StringRef &string : llvm::reverse(strings)) { 416 uint64_t stringSize; 417 if (failed(stringReader.parseVarInt(stringSize))) 418 return failure(); 419 if (stringDataEndOffset < stringSize) { 420 return stringReader.emitError( 421 "string size exceeds the available data size"); 422 } 423 424 // Extract the string from the data, dropping the null character. 425 size_t stringOffset = stringDataEndOffset - stringSize; 426 string = StringRef( 427 reinterpret_cast<const char *>(sectionData.data() + stringOffset), 428 stringSize - 1); 429 stringDataEndOffset = stringOffset; 430 } 431 432 // Check that the only remaining data was for the strings, i.e. the reader 433 // should be at the same offset as the first string. 
434 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { 435 return stringReader.emitError("unexpected trailing data between the " 436 "offsets for strings and their data"); 437 } 438 return success(); 439 } 440 441 //===----------------------------------------------------------------------===// 442 // BytecodeDialect 443 //===----------------------------------------------------------------------===// 444 445 namespace { 446 class DialectReader; 447 448 /// This struct represents a dialect entry within the bytecode. 449 struct BytecodeDialect { 450 /// Load the dialect into the provided context if it hasn't been loaded yet. 451 /// Returns failure if the dialect couldn't be loaded *and* the provided 452 /// context does not allow unregistered dialects. The provided reader is used 453 /// for error emission if necessary. 454 LogicalResult load(DialectReader &reader, MLIRContext *ctx); 455 456 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can 457 /// only be called after `load`. 458 Dialect *getLoadedDialect() const { 459 assert(dialect && 460 "expected `load` to be invoked before `getLoadedDialect`"); 461 return *dialect; 462 } 463 464 /// The loaded dialect entry. This field is std::nullopt if we haven't 465 /// attempted to load, nullptr if we failed to load, otherwise the loaded 466 /// dialect. 467 std::optional<Dialect *> dialect; 468 469 /// The bytecode interface of the dialect, or nullptr if the dialect does not 470 /// implement the bytecode interface. This field should only be checked if the 471 /// `dialect` field is not std::nullopt. 472 const BytecodeDialectInterface *interface = nullptr; 473 474 /// The name of the dialect. 475 StringRef name; 476 477 /// A buffer containing the encoding of the dialect version parsed. 478 ArrayRef<uint8_t> versionBuffer; 479 480 /// Lazy loaded dialect version from the handle above. 
std::unique_ptr<DialectVersion> loadedVersion;
};

/// This struct represents an operation name entry within the bytecode.
struct BytecodeOperationName {
  BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
                        std::optional<bool> wasRegistered)
      : dialect(dialect), name(name), wasRegistered(wasRegistered) {}

  /// The loaded operation name, or std::nullopt if it hasn't been processed
  /// yet.
  std::optional<OperationName> opName;

  /// The dialect that owns this operation name.
  BytecodeDialect *dialect;

  /// The name of the operation, without the dialect prefix.
  StringRef name;

  /// Whether this operation was registered when the bytecode was produced.
  /// This flag is populated when bytecode version >=kNativePropertiesEncoding.
  std::optional<bool> wasRegistered;
};
} // namespace

/// Parse a single dialect group encoded in the byte stream. A group is a
/// dialect index followed by an entry count; `entryCallback` is invoked once
/// per entry with the resolved dialect, and parsing stops at its first
/// failure.
static LogicalResult parseDialectGrouping(
    EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
    function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
  // Parse the dialect and the number of entries in the group.
  BytecodeDialect *dialect;
  if (failed(parseEntry(reader, dialects, dialect, "dialect")))
    return failure();
  uint64_t numEntries;
  if (failed(reader.parseVarInt(numEntries)))
    return failure();

  for (uint64_t i = 0; i < numEntries; ++i)
    if (failed(entryCallback(dialect)))
      return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// ResourceSectionReader
//===----------------------------------------------------------------------===//

namespace {
/// This class is used to read the resource section from the bytecode.
class ResourceSectionReader {
public:
  /// Initialize the resource section reader with the given section data.
LogicalResult
  initialize(Location fileLoc, const ParserConfig &config,
             MutableArrayRef<BytecodeDialect> dialects,
             StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
             ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
             const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);

  /// Parse a dialect resource handle from the resource section.
  LogicalResult parseResourceHandle(EncodingReader &reader,
                                    AsmDialectResourceHandle &result) {
    return parseEntry(reader, dialectResources, result, "resource handle");
  }

private:
  /// The table of dialect resources within the bytecode file.
  SmallVector<AsmDialectResourceHandle> dialectResources;
  /// Maps a dialect resource key as written in the bytecode to the key
  /// reported by the dialect for the handle it declared for that resource.
  llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
};

/// A single resource entry handed to a resource parser. The entry reads its
/// value from `reader` on demand, and reports an error if it is asked for a
/// kind of value other than the one recorded in `kind`.
class ParsedResourceEntry : public AsmParsedResourceEntry {
public:
  ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
                      EncodingReader &reader, StringSectionReader &stringReader,
                      const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
      : key(key), kind(kind), reader(reader), stringReader(stringReader),
        bufferOwnerRef(bufferOwnerRef) {}
  ~ParsedResourceEntry() override = default;

  StringRef getKey() const final { return key; }

  InFlightDiagnostic emitError() const final { return reader.emitError(); }

  AsmResourceEntryKind getKind() const final { return kind; }

  /// Parse the entry as a bool, which is encoded as a single byte.
  FailureOr<bool> parseAsBool() const final {
    if (kind != AsmResourceEntryKind::Bool)
      return emitError() << "expected a bool resource entry, but found a "
                         << toString(kind) << " entry instead";

    bool value;
    if (failed(reader.parseByte(value)))
      return failure();
    return value;
  }
  /// Parse the entry as a string, which is encoded as a reference into the
  /// string section.
  FailureOr<std::string> parseAsString() const final {
    if (kind != AsmResourceEntryKind::String)
      return emitError() << "expected a string resource entry, but found a "
                         << toString(kind) << " entry instead";

    StringRef string;
    if (failed(stringReader.parseString(reader, string)))
      return failure();
    return string.str();
  }

  /// Parse the entry as a blob. When `bufferOwnerRef` is set the returned blob
  /// refers directly to the bytecode buffer; otherwise the data is copied into
  /// memory obtained from `allocator`.
  FailureOr<AsmResourceBlob>
  parseAsBlob(BlobAllocatorFn allocator) const final {
    if (kind != AsmResourceEntryKind::Blob)
      return emitError() << "expected a blob resource entry, but found a "
                         << toString(kind) << " entry instead";

    ArrayRef<uint8_t> data;
    uint64_t alignment;
    if (failed(reader.parseBlobAndAlignment(data, alignment)))
      return failure();

    // If we have an extendable reference to the buffer owner, we don't need to
    // allocate a new buffer for the data, and can use the data directly.
    if (bufferOwnerRef) {
      ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
                              data.size());

      // Allocate an unmanaged buffer which captures a reference to the owner.
      // For now we just mark this as immutable, but in the future we should
      // explore marking this as mutable when desired.
      return UnmanagedAsmResourceBlob::allocateWithAlign(
          charData, alignment,
          [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
    }

    // Allocate memory for the blob using the provided allocator and copy the
    // data into it.
615 AsmResourceBlob blob = allocator(data.size(), alignment); 616 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && 617 blob.isMutable() && 618 "blob allocator did not return a properly aligned address"); 619 memcpy(blob.getMutableData().data(), data.data(), data.size()); 620 return blob; 621 } 622 623 private: 624 StringRef key; 625 AsmResourceEntryKind kind; 626 EncodingReader &reader; 627 StringSectionReader &stringReader; 628 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 629 }; 630 } // namespace 631 632 template <typename T> 633 static LogicalResult 634 parseResourceGroup(Location fileLoc, bool allowEmpty, 635 EncodingReader &offsetReader, EncodingReader &resourceReader, 636 StringSectionReader &stringReader, T *handler, 637 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, 638 function_ref<StringRef(StringRef)> remapKey = {}, 639 function_ref<LogicalResult(StringRef)> processKeyFn = {}) { 640 uint64_t numResources; 641 if (failed(offsetReader.parseVarInt(numResources))) 642 return failure(); 643 644 for (uint64_t i = 0; i < numResources; ++i) { 645 StringRef key; 646 AsmResourceEntryKind kind; 647 uint64_t resourceOffset; 648 ArrayRef<uint8_t> data; 649 if (failed(stringReader.parseString(offsetReader, key)) || 650 failed(offsetReader.parseVarInt(resourceOffset)) || 651 failed(offsetReader.parseByte(kind)) || 652 failed(resourceReader.parseBytes(resourceOffset, data))) 653 return failure(); 654 655 // Process the resource key. 656 if ((processKeyFn && failed(processKeyFn(key)))) 657 return failure(); 658 659 // If the resource data is empty and we allow it, don't error out when 660 // parsing below, just skip it. 661 if (allowEmpty && data.empty()) 662 continue; 663 664 // Ignore the entry if we don't have a valid handler. 665 if (!handler) 666 continue; 667 668 // Otherwise, parse the resource value. 
669 EncodingReader entryReader(data, fileLoc); 670 key = remapKey(key); 671 ParsedResourceEntry entry(key, kind, entryReader, stringReader, 672 bufferOwnerRef); 673 if (failed(handler->parseResource(entry))) 674 return failure(); 675 if (!entryReader.empty()) { 676 return entryReader.emitError( 677 "unexpected trailing bytes in resource entry '", key, "'"); 678 } 679 } 680 return success(); 681 } 682 683 LogicalResult ResourceSectionReader::initialize( 684 Location fileLoc, const ParserConfig &config, 685 MutableArrayRef<BytecodeDialect> dialects, 686 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 687 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 688 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 689 EncodingReader resourceReader(sectionData, fileLoc); 690 EncodingReader offsetReader(offsetSectionData, fileLoc); 691 692 // Read the number of external resource providers. 693 uint64_t numExternalResourceGroups; 694 if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) 695 return failure(); 696 697 // Utility functor that dispatches to `parseResourceGroup`, but implicitly 698 // provides most of the arguments. 699 auto parseGroup = [&](auto *handler, bool allowEmpty = false, 700 function_ref<LogicalResult(StringRef)> keyFn = {}) { 701 auto resolveKey = [&](StringRef key) -> StringRef { 702 auto it = dialectResourceHandleRenamingMap.find(key); 703 if (it == dialectResourceHandleRenamingMap.end()) 704 return ""; 705 return it->second; 706 }; 707 708 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, 709 stringReader, handler, bufferOwnerRef, resolveKey, 710 keyFn); 711 }; 712 713 // Read the external resources from the bytecode. 714 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { 715 StringRef key; 716 if (failed(stringReader.parseString(offsetReader, key))) 717 return failure(); 718 719 // Get the handler for these resources. 
720 // TODO: Should we require handling external resources in some scenarios? 721 AsmResourceParser *handler = config.getResourceParser(key); 722 if (!handler) { 723 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key 724 << "'"; 725 } 726 727 if (failed(parseGroup(handler))) 728 return failure(); 729 } 730 731 // Read the dialect resources from the bytecode. 732 MLIRContext *ctx = fileLoc->getContext(); 733 while (!offsetReader.empty()) { 734 BytecodeDialect *dialect; 735 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || 736 failed(dialect->load(dialectReader, ctx))) 737 return failure(); 738 Dialect *loadedDialect = dialect->getLoadedDialect(); 739 if (!loadedDialect) { 740 return resourceReader.emitError() 741 << "dialect '" << dialect->name << "' is unknown"; 742 } 743 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); 744 if (!handler) { 745 return resourceReader.emitError() 746 << "unexpected resources for dialect '" << dialect->name << "'"; 747 } 748 749 // Ensure that each resource is declared before being processed. 750 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { 751 FailureOr<AsmDialectResourceHandle> handle = 752 handler->declareResource(key); 753 if (failed(handle)) { 754 return resourceReader.emitError() 755 << "unknown 'resource' key '" << key << "' for dialect '" 756 << dialect->name << "'"; 757 } 758 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); 759 dialectResources.push_back(*handle); 760 return success(); 761 }; 762 763 // Parse the resources for this dialect. We allow empty resources because we 764 // just treat these as declarations. 
765 if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) 766 return failure(); 767 } 768 769 return success(); 770 } 771 772 //===----------------------------------------------------------------------===// 773 // Attribute/Type Reader 774 //===----------------------------------------------------------------------===// 775 776 namespace { 777 /// This class provides support for reading attribute and type entries from the 778 /// bytecode. Attribute and Type entries are read lazily on demand, so we use 779 /// this reader to manage when to actually parse them from the bytecode. 780 class AttrTypeReader { 781 /// This class represents a single attribute or type entry. 782 template <typename T> 783 struct Entry { 784 /// The entry, or null if it hasn't been resolved yet. 785 T entry = {}; 786 /// The parent dialect of this entry. 787 BytecodeDialect *dialect = nullptr; 788 /// A flag indicating if the entry was encoded using a custom encoding, 789 /// instead of using the textual assembly format. 790 bool hasCustomEncoding = false; 791 /// The raw data of this entry in the bytecode. 792 ArrayRef<uint8_t> data; 793 }; 794 using AttrEntry = Entry<Attribute>; 795 using TypeEntry = Entry<Type>; 796 797 public: 798 AttrTypeReader(StringSectionReader &stringReader, 799 ResourceSectionReader &resourceReader, Location fileLoc) 800 : stringReader(stringReader), resourceReader(resourceReader), 801 fileLoc(fileLoc) {} 802 803 /// Initialize the attribute and type information within the reader. 804 LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, 805 ArrayRef<uint8_t> sectionData, 806 ArrayRef<uint8_t> offsetSectionData); 807 808 /// Resolve the attribute or type at the given index. Returns nullptr on 809 /// failure. 
Attribute resolveAttribute(size_t index) {
    return resolveEntry(attributes, index, "Attribute");
  }
  Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }

  /// Parse a reference to an attribute or type using the given reader.
  LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
    uint64_t attrIdx;
    if (failed(reader.parseVarInt(attrIdx)))
      return failure();
    result = resolveAttribute(attrIdx);
    // A null result means the entry could not be resolved.
    return success(!!result);
  }
  /// Parse an attribute reference whose index is encoded together with a
  /// presence flag.
  LogicalResult parseOptionalAttribute(EncodingReader &reader,
                                       Attribute &result) {
    uint64_t attrIdx;
    bool flag;
    if (failed(reader.parseVarIntWithFlag(attrIdx, flag)))
      return failure();
    // When the flag is unset nothing is resolved and `result` is left
    // untouched.
    if (!flag)
      return success();
    result = resolveAttribute(attrIdx);
    return success(!!result);
  }

  LogicalResult parseType(EncodingReader &reader, Type &result) {
    uint64_t typeIdx;
    if (failed(reader.parseVarInt(typeIdx)))
      return failure();
    result = resolveType(typeIdx);
    return success(!!result);
  }

  /// Parse an attribute reference and check that it is of type `T`, emitting
  /// an error if it is not.
  template <typename T>
  LogicalResult parseAttribute(EncodingReader &reader, T &result) {
    Attribute baseResult;
    if (failed(parseAttribute(reader, baseResult)))
      return failure();
    if ((result = dyn_cast<T>(baseResult)))
      return success();
    return reader.emitError("expected attribute of type: ",
                            llvm::getTypeName<T>(), ", but got: ", baseResult);
  }

private:
  /// Resolve the given entry at `index`.
  template <typename T>
  T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
                 StringRef entryType);

  /// Parse an entry using the given reader that was encoded using the textual
  /// assembly format.
  template <typename T>
  LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
                              StringRef entryType);

  /// Parse an entry using the given reader that was encoded using a custom
  /// bytecode format.
  template <typename T>
  LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
                                 StringRef entryType);

  /// The string section reader used to resolve string references when parsing
  /// custom encoded attribute/type entries.
  StringSectionReader &stringReader;

  /// The resource section reader used to resolve resource references when
  /// parsing custom encoded attribute/type entries.
  ResourceSectionReader &resourceReader;

  /// The set of attribute and type entries.
  SmallVector<AttrEntry> attributes;
  SmallVector<TypeEntry> types;

  /// A location used for error emission.
  Location fileLoc;
};

class DialectReader : public DialectBytecodeReader {
public:
  DialectReader(AttrTypeReader &attrTypeReader,
                StringSectionReader &stringReader,
                ResourceSectionReader &resourceReader, EncodingReader &reader)
      : attrTypeReader(attrTypeReader), stringReader(stringReader),
        resourceReader(resourceReader), reader(reader) {}

  InFlightDiagnostic emitError(const Twine &msg) override {
    return reader.emitError(msg);
  }

  /// Return a reader that shares this reader's attribute/type, string and
  /// resource section readers, but reads from `encReader`.
  DialectReader withEncodingReader(EncodingReader &encReader) {
    return DialectReader(attrTypeReader, stringReader, resourceReader,
                         encReader);
  }

  Location getLoc() const { return reader.getLoc(); }

  //===--------------------------------------------------------------------===//
  // IR
  //===--------------------------------------------------------------------===//

  LogicalResult readAttribute(Attribute &result) override {
    return attrTypeReader.parseAttribute(reader, result);
  }
  LogicalResult readOptionalAttribute(Attribute &result) override {
    return attrTypeReader.parseOptionalAttribute(reader, result);
  }
  LogicalResult readType(Type &result) override {
    return attrTypeReader.parseType(reader, result);
  }

  FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
AsmDialectResourceHandle handle; 923 if (failed(resourceReader.parseResourceHandle(reader, handle))) 924 return failure(); 925 return handle; 926 } 927 928 //===--------------------------------------------------------------------===// 929 // Primitives 930 //===--------------------------------------------------------------------===// 931 932 LogicalResult readVarInt(uint64_t &result) override { 933 return reader.parseVarInt(result); 934 } 935 936 LogicalResult readSignedVarInt(int64_t &result) override { 937 uint64_t unsignedResult; 938 if (failed(reader.parseSignedVarInt(unsignedResult))) 939 return failure(); 940 result = static_cast<int64_t>(unsignedResult); 941 return success(); 942 } 943 944 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { 945 // Small values are encoded using a single byte. 946 if (bitWidth <= 8) { 947 uint8_t value; 948 if (failed(reader.parseByte(value))) 949 return failure(); 950 return APInt(bitWidth, value); 951 } 952 953 // Large values up to 64 bits are encoded using a single varint. 954 if (bitWidth <= 64) { 955 uint64_t value; 956 if (failed(reader.parseSignedVarInt(value))) 957 return failure(); 958 return APInt(bitWidth, value); 959 } 960 961 // Otherwise, for really big values we encode the array of active words in 962 // the value. 
963 uint64_t numActiveWords; 964 if (failed(reader.parseVarInt(numActiveWords))) 965 return failure(); 966 SmallVector<uint64_t, 4> words(numActiveWords); 967 for (uint64_t i = 0; i < numActiveWords; ++i) 968 if (failed(reader.parseSignedVarInt(words[i]))) 969 return failure(); 970 return APInt(bitWidth, words); 971 } 972 973 FailureOr<APFloat> 974 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { 975 FailureOr<APInt> intVal = 976 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); 977 if (failed(intVal)) 978 return failure(); 979 return APFloat(semantics, *intVal); 980 } 981 982 LogicalResult readString(StringRef &result) override { 983 return stringReader.parseString(reader, result); 984 } 985 986 LogicalResult readBlob(ArrayRef<char> &result) override { 987 uint64_t dataSize; 988 ArrayRef<uint8_t> data; 989 if (failed(reader.parseVarInt(dataSize)) || 990 failed(reader.parseBytes(dataSize, data))) 991 return failure(); 992 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), 993 data.size()); 994 return success(); 995 } 996 997 LogicalResult readBool(bool &result) override { 998 return reader.parseByte(result); 999 } 1000 1001 private: 1002 AttrTypeReader &attrTypeReader; 1003 StringSectionReader &stringReader; 1004 ResourceSectionReader &resourceReader; 1005 EncodingReader &reader; 1006 }; 1007 1008 /// Wraps the properties section and handles reading properties out of it. 1009 class PropertiesSectionReader { 1010 public: 1011 /// Initialize the properties section reader with the given section data. 1012 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { 1013 if (sectionData.empty()) 1014 return success(); 1015 EncodingReader propReader(sectionData, fileLoc); 1016 uint64_t count; 1017 if (failed(propReader.parseVarInt(count))) 1018 return failure(); 1019 // Parse the raw properties buffer. 
1020 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) 1021 return failure(); 1022 1023 EncodingReader offsetsReader(propertiesBuffers, fileLoc); 1024 offsetTable.reserve(count); 1025 for (auto idx : llvm::seq<int64_t>(0, count)) { 1026 (void)idx; 1027 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); 1028 ArrayRef<uint8_t> rawProperties; 1029 uint64_t dataSize; 1030 if (failed(offsetsReader.parseVarInt(dataSize)) || 1031 failed(offsetsReader.parseBytes(dataSize, rawProperties))) 1032 return failure(); 1033 } 1034 if (!offsetsReader.empty()) 1035 return offsetsReader.emitError() 1036 << "Broken properties section: didn't exhaust the offsets table"; 1037 return success(); 1038 } 1039 1040 LogicalResult read(Location fileLoc, DialectReader &dialectReader, 1041 OperationName *opName, OperationState &opState) { 1042 uint64_t propertiesIdx; 1043 if (failed(dialectReader.readVarInt(propertiesIdx))) 1044 return failure(); 1045 if (propertiesIdx >= offsetTable.size()) 1046 return dialectReader.emitError("Properties idx out-of-bound for ") 1047 << opName->getStringRef(); 1048 size_t propertiesOffset = offsetTable[propertiesIdx]; 1049 if (propertiesIdx >= propertiesBuffers.size()) 1050 return dialectReader.emitError("Properties offset out-of-bound for ") 1051 << opName->getStringRef(); 1052 1053 // Acquire the sub-buffer that represent the requested properties. 1054 ArrayRef<char> rawProperties; 1055 { 1056 // "Seek" to the requested offset by getting a new reader with the right 1057 // sub-buffer. 1058 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), 1059 fileLoc); 1060 // Properties are stored as a sequence of {size + raw_data}. 1061 if (failed( 1062 dialectReader.withEncodingReader(reader).readBlob(rawProperties))) 1063 return failure(); 1064 } 1065 // Setup a new reader to read from the `rawProperties` sub-buffer. 
1066 EncodingReader reader( 1067 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); 1068 DialectReader propReader = dialectReader.withEncodingReader(reader); 1069 1070 auto *iface = opName->getInterface<BytecodeOpInterface>(); 1071 if (iface) 1072 return iface->readProperties(propReader, opState); 1073 if (opName->isRegistered()) 1074 return propReader.emitError( 1075 "has properties but missing BytecodeOpInterface for ") 1076 << opName->getStringRef(); 1077 // Unregistered op are storing properties as an attribute. 1078 return propReader.readAttribute(opState.propertiesAttr); 1079 } 1080 1081 private: 1082 /// The properties buffer referenced within the bytecode file. 1083 ArrayRef<uint8_t> propertiesBuffers; 1084 1085 /// Table of offset in the buffer above. 1086 SmallVector<int64_t> offsetTable; 1087 }; 1088 } // namespace 1089 1090 LogicalResult 1091 AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, 1092 ArrayRef<uint8_t> sectionData, 1093 ArrayRef<uint8_t> offsetSectionData) { 1094 EncodingReader offsetReader(offsetSectionData, fileLoc); 1095 1096 // Parse the number of attribute and type entries. 1097 uint64_t numAttributes, numTypes; 1098 if (failed(offsetReader.parseVarInt(numAttributes)) || 1099 failed(offsetReader.parseVarInt(numTypes))) 1100 return failure(); 1101 attributes.resize(numAttributes); 1102 types.resize(numTypes); 1103 1104 // A functor used to accumulate the offsets for the entries in the given 1105 // range. 1106 uint64_t currentOffset = 0; 1107 auto parseEntries = [&](auto &&range) { 1108 size_t currentIndex = 0, endIndex = range.size(); 1109 1110 // Parse an individual entry. 1111 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { 1112 auto &entry = range[currentIndex++]; 1113 1114 uint64_t entrySize; 1115 if (failed(offsetReader.parseVarIntWithFlag(entrySize, 1116 entry.hasCustomEncoding))) 1117 return failure(); 1118 1119 // Verify that the offset is actually valid. 
1120 if (currentOffset + entrySize > sectionData.size()) { 1121 return offsetReader.emitError( 1122 "Attribute or Type entry offset points past the end of section"); 1123 } 1124 1125 entry.data = sectionData.slice(currentOffset, entrySize); 1126 entry.dialect = dialect; 1127 currentOffset += entrySize; 1128 return success(); 1129 }; 1130 while (currentIndex != endIndex) 1131 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) 1132 return failure(); 1133 return success(); 1134 }; 1135 1136 // Process each of the attributes, and then the types. 1137 if (failed(parseEntries(attributes)) || failed(parseEntries(types))) 1138 return failure(); 1139 1140 // Ensure that we read everything from the section. 1141 if (!offsetReader.empty()) { 1142 return offsetReader.emitError( 1143 "unexpected trailing data in the Attribute/Type offset section"); 1144 } 1145 return success(); 1146 } 1147 1148 template <typename T> 1149 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 1150 StringRef entryType) { 1151 if (index >= entries.size()) { 1152 emitError(fileLoc) << "invalid " << entryType << " index: " << index; 1153 return {}; 1154 } 1155 1156 // If the entry has already been resolved, there is nothing left to do. 1157 Entry<T> &entry = entries[index]; 1158 if (entry.entry) 1159 return entry.entry; 1160 1161 // Parse the entry. 1162 EncodingReader reader(entry.data, fileLoc); 1163 1164 // Parse based on how the entry was encoded. 
1165 if (entry.hasCustomEncoding) { 1166 if (failed(parseCustomEntry(entry, reader, entryType))) 1167 return T(); 1168 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { 1169 return T(); 1170 } 1171 1172 if (!reader.empty()) { 1173 reader.emitError("unexpected trailing bytes after " + entryType + " entry"); 1174 return T(); 1175 } 1176 return entry.entry; 1177 } 1178 1179 template <typename T> 1180 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, 1181 StringRef entryType) { 1182 StringRef asmStr; 1183 if (failed(reader.parseNullTerminatedString(asmStr))) 1184 return failure(); 1185 1186 // Invoke the MLIR assembly parser to parse the entry text. 1187 size_t numRead = 0; 1188 MLIRContext *context = fileLoc->getContext(); 1189 if constexpr (std::is_same_v<T, Type>) 1190 result = 1191 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); 1192 else 1193 result = ::parseAttribute(asmStr, context, Type(), &numRead, 1194 /*isKnownNullTerminated=*/true); 1195 if (!result) 1196 return failure(); 1197 1198 // Ensure there weren't dangling characters after the entry. 1199 if (numRead != asmStr.size()) { 1200 return reader.emitError("trailing characters found after ", entryType, 1201 " assembly format: ", asmStr.drop_front(numRead)); 1202 } 1203 return success(); 1204 } 1205 1206 template <typename T> 1207 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, 1208 EncodingReader &reader, 1209 StringRef entryType) { 1210 DialectReader dialectReader(*this, stringReader, resourceReader, reader); 1211 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) 1212 return failure(); 1213 // Ensure that the dialect implements the bytecode interface. 1214 if (!entry.dialect->interface) { 1215 return reader.emitError("dialect '", entry.dialect->name, 1216 "' does not implement the bytecode interface"); 1217 } 1218 1219 // Ask the dialect to parse the entry. 
If the dialect is versioned, parse 1220 // using the versioned encoding readers. 1221 if (entry.dialect->loadedVersion.get()) { 1222 if constexpr (std::is_same_v<T, Type>) 1223 entry.entry = entry.dialect->interface->readType( 1224 dialectReader, *entry.dialect->loadedVersion); 1225 else 1226 entry.entry = entry.dialect->interface->readAttribute( 1227 dialectReader, *entry.dialect->loadedVersion); 1228 1229 } else { 1230 if constexpr (std::is_same_v<T, Type>) 1231 entry.entry = entry.dialect->interface->readType(dialectReader); 1232 else 1233 entry.entry = entry.dialect->interface->readAttribute(dialectReader); 1234 } 1235 return success(!!entry.entry); 1236 } 1237 1238 //===----------------------------------------------------------------------===// 1239 // Bytecode Reader 1240 //===----------------------------------------------------------------------===// 1241 1242 /// This class is used to read a bytecode buffer and translate it into MLIR. 1243 class mlir::BytecodeReader::Impl { 1244 struct RegionReadState; 1245 using LazyLoadableOpsInfo = 1246 std::list<std::pair<Operation *, RegionReadState>>; 1247 using LazyLoadableOpsMap = 1248 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; 1249 1250 public: 1251 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, 1252 llvm::MemoryBufferRef buffer, 1253 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 1254 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), 1255 attrTypeReader(stringReader, resourceReader, fileLoc), 1256 // Use the builtin unrealized conversion cast operation to represent 1257 // forward references to values that aren't yet defined. 1258 forwardRefOpState(UnknownLoc::get(config.getContext()), 1259 "builtin.unrealized_conversion_cast", ValueRange(), 1260 NoneType::get(config.getContext())), 1261 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} 1262 1263 /// Read the bytecode defined within `buffer` into the given block. 
1264 LogicalResult read(Block *block, 1265 llvm::function_ref<bool(Operation *)> lazyOps); 1266 1267 /// Return the number of ops that haven't been materialized yet. 1268 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } 1269 1270 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } 1271 1272 /// Materialize the provided operation, invoke the lazyOpsCallback on every 1273 /// newly found lazy operation. 1274 LogicalResult 1275 materialize(Operation *op, 1276 llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1277 this->lazyOpsCallback = lazyOpsCallback; 1278 auto resetlazyOpsCallback = 1279 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1280 auto it = lazyLoadableOpsMap.find(op); 1281 assert(it != lazyLoadableOpsMap.end() && 1282 "materialize called on non-materializable op"); 1283 return materialize(it); 1284 } 1285 1286 /// Materialize all operations. 1287 LogicalResult materializeAll() { 1288 while (!lazyLoadableOpsMap.empty()) { 1289 if (failed(materialize(lazyLoadableOpsMap.begin()))) 1290 return failure(); 1291 } 1292 return success(); 1293 } 1294 1295 /// Finalize the lazy-loading by calling back with every op that hasn't been 1296 /// materialized to let the client decide if the op should be deleted or 1297 /// materialized. The op is materialized if the callback returns true, deleted 1298 /// otherwise. 
1299 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { 1300 while (!lazyLoadableOps.empty()) { 1301 Operation *op = lazyLoadableOps.begin()->first; 1302 if (shouldMaterialize(op)) { 1303 if (failed(materialize(lazyLoadableOpsMap.find(op)))) 1304 return failure(); 1305 continue; 1306 } 1307 op->dropAllReferences(); 1308 op->erase(); 1309 lazyLoadableOps.pop_front(); 1310 lazyLoadableOpsMap.erase(op); 1311 } 1312 return success(); 1313 } 1314 1315 private: 1316 LogicalResult materialize(LazyLoadableOpsMap::iterator it) { 1317 assert(it != lazyLoadableOpsMap.end() && 1318 "materialize called on non-materializable op"); 1319 valueScopes.emplace_back(); 1320 std::vector<RegionReadState> regionStack; 1321 regionStack.push_back(std::move(it->getSecond()->second)); 1322 lazyLoadableOps.erase(it->getSecond()); 1323 lazyLoadableOpsMap.erase(it); 1324 1325 while (!regionStack.empty()) 1326 if (failed(parseRegions(regionStack, regionStack.back()))) 1327 return failure(); 1328 return success(); 1329 } 1330 1331 /// Return the context for this config. 1332 MLIRContext *getContext() const { return config.getContext(); } 1333 1334 /// Parse the bytecode version. 1335 LogicalResult parseVersion(EncodingReader &reader); 1336 1337 //===--------------------------------------------------------------------===// 1338 // Dialect Section 1339 1340 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); 1341 1342 /// Parse an operation name reference using the given reader, and set the 1343 /// `wasRegistered` flag that indicates if the bytecode was produced by a 1344 /// context where opName was registered. 1345 FailureOr<OperationName> parseOpName(EncodingReader &reader, 1346 std::optional<bool> &wasRegistered); 1347 1348 //===--------------------------------------------------------------------===// 1349 // Attribute/Type Section 1350 1351 /// Parse an attribute or type using the given reader. 
1352 template <typename T> 1353 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 1354 return attrTypeReader.parseAttribute(reader, result); 1355 } 1356 LogicalResult parseType(EncodingReader &reader, Type &result) { 1357 return attrTypeReader.parseType(reader, result); 1358 } 1359 1360 //===--------------------------------------------------------------------===// 1361 // Resource Section 1362 1363 LogicalResult 1364 parseResourceSection(EncodingReader &reader, 1365 std::optional<ArrayRef<uint8_t>> resourceData, 1366 std::optional<ArrayRef<uint8_t>> resourceOffsetData); 1367 1368 //===--------------------------------------------------------------------===// 1369 // IR Section 1370 1371 /// This struct represents the current read state of a range of regions. This 1372 /// struct is used to enable iterative parsing of regions. 1373 struct RegionReadState { 1374 RegionReadState(Operation *op, EncodingReader *reader, 1375 bool isIsolatedFromAbove) 1376 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} 1377 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, 1378 bool isIsolatedFromAbove) 1379 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), 1380 isIsolatedFromAbove(isIsolatedFromAbove) {} 1381 1382 /// The current regions being read. 1383 MutableArrayRef<Region>::iterator curRegion, endRegion; 1384 /// This is the reader to use for this region, this pointer is pointing to 1385 /// the parent region reader unless the current region is IsolatedFromAbove, 1386 /// in which case the pointer is pointing to the `owningReader` which is a 1387 /// section dedicated to the current region. 1388 EncodingReader *reader; 1389 std::unique_ptr<EncodingReader> owningReader; 1390 1391 /// The number of values defined immediately within this region. 1392 unsigned numValues = 0; 1393 1394 /// The current blocks of the region being read. 
1395 SmallVector<Block *> curBlocks; 1396 Region::iterator curBlock = {}; 1397 1398 /// The number of operations remaining to be read from the current block 1399 /// being read. 1400 uint64_t numOpsRemaining = 0; 1401 1402 /// A flag indicating if the regions being read are isolated from above. 1403 bool isIsolatedFromAbove = false; 1404 }; 1405 1406 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); 1407 LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, 1408 RegionReadState &readState); 1409 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, 1410 RegionReadState &readState, 1411 bool &isIsolatedFromAbove); 1412 1413 LogicalResult parseRegion(RegionReadState &readState); 1414 LogicalResult parseBlockHeader(EncodingReader &reader, 1415 RegionReadState &readState); 1416 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); 1417 1418 //===--------------------------------------------------------------------===// 1419 // Value Processing 1420 1421 /// Parse an operand reference using the given reader. Returns nullptr in the 1422 /// case of failure. 1423 Value parseOperand(EncodingReader &reader); 1424 1425 /// Sequentially define the given value range. 1426 LogicalResult defineValues(EncodingReader &reader, ValueRange values); 1427 1428 /// Create a value to use for a forward reference. 1429 Value createForwardRef(); 1430 1431 //===--------------------------------------------------------------------===// 1432 // Use-list order helpers 1433 1434 /// This struct is a simple storage that contains information required to 1435 /// reorder the use-list of a value with respect to the pre-order traversal 1436 /// ordering. 
1437 struct UseListOrderStorage { 1438 UseListOrderStorage(bool isIndexPairEncoding, 1439 SmallVector<unsigned, 4> &&indices) 1440 : indices(std::move(indices)), 1441 isIndexPairEncoding(isIndexPairEncoding){}; 1442 /// The vector containing the information required to reorder the 1443 /// use-list of a value. 1444 SmallVector<unsigned, 4> indices; 1445 1446 /// Whether indices represent a pair of type `(src, dst)` or it is a direct 1447 /// indexing, such as `dst = order[src]`. 1448 bool isIndexPairEncoding; 1449 }; 1450 1451 /// Parse use-list order from bytecode for a range of values if available. The 1452 /// range is expected to be either a block argument or an op result range. On 1453 /// success, return a map of the position in the range and the use-list order 1454 /// encoding. The function assumes to know the size of the range it is 1455 /// processing. 1456 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; 1457 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, 1458 uint64_t rangeSize); 1459 1460 /// Shuffle the use-chain according to the order parsed. 1461 LogicalResult sortUseListOrder(Value value); 1462 1463 /// Recursively visit all the values defined within topLevelOp and sort the 1464 /// use-list orders according to the indices parsed. 1465 LogicalResult processUseLists(Operation *topLevelOp); 1466 1467 //===--------------------------------------------------------------------===// 1468 // Fields 1469 1470 /// This class represents a single value scope, in which a value scope is 1471 /// delimited by isolated from above regions. 1472 struct ValueScope { 1473 /// Push a new region state onto this scope, reserving enough values for 1474 /// those defined within the current region of the provided state. 
1475 void push(RegionReadState &readState) { 1476 nextValueIDs.push_back(values.size()); 1477 values.resize(values.size() + readState.numValues); 1478 } 1479 1480 /// Pop the values defined for the current region within the provided region 1481 /// state. 1482 void pop(RegionReadState &readState) { 1483 values.resize(values.size() - readState.numValues); 1484 nextValueIDs.pop_back(); 1485 } 1486 1487 /// The set of values defined in this scope. 1488 std::vector<Value> values; 1489 1490 /// The ID for the next defined value for each region current being 1491 /// processed in this scope. 1492 SmallVector<unsigned, 4> nextValueIDs; 1493 }; 1494 1495 /// The configuration of the parser. 1496 const ParserConfig &config; 1497 1498 /// A location to use when emitting errors. 1499 Location fileLoc; 1500 1501 /// Flag that indicates if lazyloading is enabled. 1502 bool lazyLoading; 1503 1504 /// Keep track of operations that have been lazy loaded (their regions haven't 1505 /// been materialized), along with the `RegionReadState` that allows to 1506 /// lazy-load the regions nested under the operation. 1507 LazyLoadableOpsInfo lazyLoadableOps; 1508 LazyLoadableOpsMap lazyLoadableOpsMap; 1509 llvm::function_ref<bool(Operation *)> lazyOpsCallback; 1510 1511 /// The reader used to process attribute and types within the bytecode. 1512 AttrTypeReader attrTypeReader; 1513 1514 /// The version of the bytecode being read. 1515 uint64_t version = 0; 1516 1517 /// The producer of the bytecode being read. 1518 StringRef producer; 1519 1520 /// The table of IR units referenced within the bytecode file. 1521 SmallVector<BytecodeDialect> dialects; 1522 SmallVector<BytecodeOperationName> opNames; 1523 1524 /// The reader used to process resources within the bytecode. 1525 ResourceSectionReader resourceReader; 1526 1527 /// Worklist of values with custom use-list orders to process before the end 1528 /// of the parsing. 
1529 DenseMap<void *, UseListOrderStorage> valueToUseListMap; 1530 1531 /// The table of strings referenced within the bytecode file. 1532 StringSectionReader stringReader; 1533 1534 /// The table of properties referenced by the operation in the bytecode file. 1535 PropertiesSectionReader propertiesReader; 1536 1537 /// The current set of available IR value scopes. 1538 std::vector<ValueScope> valueScopes; 1539 1540 /// The global pre-order operation ordering. 1541 DenseMap<Operation *, unsigned> operationIDs; 1542 1543 /// A block containing the set of operations defined to create forward 1544 /// references. 1545 Block forwardRefOps; 1546 1547 /// A block containing previously created, and no longer used, forward 1548 /// reference operations. 1549 Block openForwardRefOps; 1550 1551 /// An operation state used when instantiating forward references. 1552 OperationState forwardRefOpState; 1553 1554 /// Reference to the input buffer. 1555 llvm::MemoryBufferRef buffer; 1556 1557 /// The optional owning source manager, which when present may be used to 1558 /// extend the lifetime of the input buffer. 1559 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 1560 }; 1561 1562 LogicalResult BytecodeReader::Impl::read( 1563 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1564 EncodingReader reader(buffer.getBuffer(), fileLoc); 1565 this->lazyOpsCallback = lazyOpsCallback; 1566 auto resetlazyOpsCallback = 1567 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1568 1569 // Skip over the bytecode header, this should have already been checked. 1570 if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) 1571 return failure(); 1572 // Parse the bytecode version and producer. 1573 if (failed(parseVersion(reader)) || 1574 failed(reader.parseNullTerminatedString(producer))) 1575 return failure(); 1576 1577 // Add a diagnostic handler that attaches a note that includes the original 1578 // producer of the bytecode. 
1579 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { 1580 diag.attachNote() << "in bytecode version " << version 1581 << " produced by: " << producer; 1582 return failure(); 1583 }); 1584 1585 // Parse the raw data for each of the top-level sections of the bytecode. 1586 std::optional<ArrayRef<uint8_t>> 1587 sectionDatas[bytecode::Section::kNumSections]; 1588 while (!reader.empty()) { 1589 // Read the next section from the bytecode. 1590 bytecode::Section::ID sectionID; 1591 ArrayRef<uint8_t> sectionData; 1592 if (failed(reader.parseSection(sectionID, sectionData))) 1593 return failure(); 1594 1595 // Check for duplicate sections, we only expect one instance of each. 1596 if (sectionDatas[sectionID]) { 1597 return reader.emitError("duplicate top-level section: ", 1598 ::toString(sectionID)); 1599 } 1600 sectionDatas[sectionID] = sectionData; 1601 } 1602 // Check that all of the required sections were found. 1603 for (int i = 0; i < bytecode::Section::kNumSections; ++i) { 1604 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); 1605 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { 1606 return reader.emitError("missing data for top-level section: ", 1607 ::toString(sectionID)); 1608 } 1609 } 1610 1611 // Process the string section first. 1612 if (failed(stringReader.initialize( 1613 fileLoc, *sectionDatas[bytecode::Section::kString]))) 1614 return failure(); 1615 1616 // Process the properties section. 1617 if (sectionDatas[bytecode::Section::kProperties] && 1618 failed(propertiesReader.initialize( 1619 fileLoc, *sectionDatas[bytecode::Section::kProperties]))) 1620 return failure(); 1621 1622 // Process the dialect section. 1623 if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) 1624 return failure(); 1625 1626 // Process the resource section if present. 
1627 if (failed(parseResourceSection( 1628 reader, sectionDatas[bytecode::Section::kResource], 1629 sectionDatas[bytecode::Section::kResourceOffset]))) 1630 return failure(); 1631 1632 // Process the attribute and type section. 1633 if (failed(attrTypeReader.initialize( 1634 dialects, *sectionDatas[bytecode::Section::kAttrType], 1635 *sectionDatas[bytecode::Section::kAttrTypeOffset]))) 1636 return failure(); 1637 1638 // Finally, process the IR section. 1639 return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); 1640 } 1641 1642 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { 1643 if (failed(reader.parseVarInt(version))) 1644 return failure(); 1645 1646 // Validate the bytecode version. 1647 uint64_t currentVersion = bytecode::kVersion; 1648 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; 1649 if (version < minSupportedVersion) { 1650 return reader.emitError("bytecode version ", version, 1651 " is older than the current version of ", 1652 currentVersion, ", and upgrade is not supported"); 1653 } 1654 if (version > currentVersion) { 1655 return reader.emitError("bytecode version ", version, 1656 " is newer than the current version ", 1657 currentVersion); 1658 } 1659 // Override any request to lazy-load if the bytecode version is too old. 1660 if (version < bytecode::kLazyLoading) 1661 lazyLoading = false; 1662 return success(); 1663 } 1664 1665 //===----------------------------------------------------------------------===// 1666 // Dialect Section 1667 1668 LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { 1669 if (dialect) 1670 return success(); 1671 Dialect *loadedDialect = ctx->getOrLoadDialect(name); 1672 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { 1673 return reader.emitError("dialect '") 1674 << name 1675 << "' is unknown. 
If this is intended, please call " 1676 "allowUnregisteredDialects() on the MLIRContext, or use " 1677 "-allow-unregistered-dialect with the MLIR tool used."; 1678 } 1679 dialect = loadedDialect; 1680 1681 // If the dialect was actually loaded, check to see if it has a bytecode 1682 // interface. 1683 if (loadedDialect) 1684 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); 1685 if (!versionBuffer.empty()) { 1686 if (!interface) 1687 return reader.emitError("dialect '") 1688 << name 1689 << "' does not implement the bytecode interface, " 1690 "but found a version entry"; 1691 EncodingReader encReader(versionBuffer, reader.getLoc()); 1692 DialectReader versionReader = reader.withEncodingReader(encReader); 1693 loadedVersion = interface->readVersion(versionReader); 1694 if (!loadedVersion) 1695 return failure(); 1696 } 1697 return success(); 1698 } 1699 1700 LogicalResult 1701 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { 1702 EncodingReader sectionReader(sectionData, fileLoc); 1703 1704 // Parse the number of dialects in the section. 1705 uint64_t numDialects; 1706 if (failed(sectionReader.parseVarInt(numDialects))) 1707 return failure(); 1708 dialects.resize(numDialects); 1709 1710 // Parse each of the dialects. 1711 for (uint64_t i = 0; i < numDialects; ++i) { 1712 /// Before version kDialectVersioning, there wasn't any versioning available 1713 /// for dialects, and the entryIdx represent the string itself. 1714 if (version < bytecode::kDialectVersioning) { 1715 if (failed(stringReader.parseString(sectionReader, dialects[i].name))) 1716 return failure(); 1717 continue; 1718 } 1719 // Parse ID representing dialect and version. 
1720 uint64_t dialectNameIdx; 1721 bool versionAvailable; 1722 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, 1723 versionAvailable))) 1724 return failure(); 1725 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, 1726 dialects[i].name))) 1727 return failure(); 1728 if (versionAvailable) { 1729 bytecode::Section::ID sectionID; 1730 if (failed( 1731 sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) 1732 return failure(); 1733 if (sectionID != bytecode::Section::kDialectVersions) { 1734 emitError(fileLoc, "expected dialect version section"); 1735 return failure(); 1736 } 1737 } 1738 } 1739 1740 // Parse the operation names, which are grouped by dialect. 1741 auto parseOpName = [&](BytecodeDialect *dialect) { 1742 StringRef opName; 1743 std::optional<bool> wasRegistered; 1744 // Prior to version kNativePropertiesEncoding, the information about wheter 1745 // an op was registered or not wasn't encoded. 1746 if (version < bytecode::kNativePropertiesEncoding) { 1747 if (failed(stringReader.parseString(sectionReader, opName))) 1748 return failure(); 1749 } else { 1750 bool wasRegisteredFlag; 1751 if (failed(stringReader.parseStringWithFlag(sectionReader, opName, 1752 wasRegisteredFlag))) 1753 return failure(); 1754 wasRegistered = wasRegisteredFlag; 1755 } 1756 opNames.emplace_back(dialect, opName, wasRegistered); 1757 return success(); 1758 }; 1759 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation 1760 // where the number of ops are known. 
1761 if (version >= bytecode::kElideUnknownBlockArgLocation) { 1762 uint64_t numOps; 1763 if (failed(sectionReader.parseVarInt(numOps))) 1764 return failure(); 1765 opNames.reserve(numOps); 1766 } 1767 while (!sectionReader.empty()) 1768 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) 1769 return failure(); 1770 return success(); 1771 } 1772 1773 FailureOr<OperationName> 1774 BytecodeReader::Impl::parseOpName(EncodingReader &reader, 1775 std::optional<bool> &wasRegistered) { 1776 BytecodeOperationName *opName = nullptr; 1777 if (failed(parseEntry(reader, opNames, opName, "operation name"))) 1778 return failure(); 1779 wasRegistered = opName->wasRegistered; 1780 // Check to see if this operation name has already been resolved. If we 1781 // haven't, load the dialect and build the operation name. 1782 if (!opName->opName) { 1783 // Load the dialect and its version. 1784 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1785 reader); 1786 if (failed(opName->dialect->load(dialectReader, getContext()))) 1787 return failure(); 1788 // If the opName is empty, this is because we use to accept names such as 1789 // `foo` without any `.` separator. We shouldn't tolerate this in textual 1790 // format anymore but for now we'll be backward compatible. This can only 1791 // happen with unregistered dialects. 1792 if (opName->name.empty()) { 1793 if (opName->dialect->getLoadedDialect()) 1794 return emitError(fileLoc) << "has an empty opname for dialect '" 1795 << opName->dialect->name << "'\n"; 1796 1797 opName->opName.emplace(opName->dialect->name, getContext()); 1798 } else { 1799 opName->opName.emplace((opName->dialect->name + "." 
+ opName->name).str(), 1800 getContext()); 1801 } 1802 } 1803 return *opName->opName; 1804 } 1805 1806 //===----------------------------------------------------------------------===// 1807 // Resource Section 1808 1809 LogicalResult BytecodeReader::Impl::parseResourceSection( 1810 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, 1811 std::optional<ArrayRef<uint8_t>> resourceOffsetData) { 1812 // Ensure both sections are either present or not. 1813 if (resourceData.has_value() != resourceOffsetData.has_value()) { 1814 if (resourceOffsetData) 1815 return emitError(fileLoc, "unexpected resource offset section when " 1816 "resource section is not present"); 1817 return emitError( 1818 fileLoc, 1819 "expected resource offset section when resource section is present"); 1820 } 1821 1822 // If the resource sections are absent, there is nothing to do. 1823 if (!resourceData) 1824 return success(); 1825 1826 // Initialize the resource reader with the resource sections. 1827 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1828 reader); 1829 return resourceReader.initialize(fileLoc, config, dialects, stringReader, 1830 *resourceData, *resourceOffsetData, 1831 dialectReader, bufferOwnerRef); 1832 } 1833 1834 //===----------------------------------------------------------------------===// 1835 // UseListOrder Helpers 1836 1837 FailureOr<BytecodeReader::Impl::UseListMapT> 1838 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, 1839 uint64_t numResults) { 1840 BytecodeReader::Impl::UseListMapT map; 1841 uint64_t numValuesToRead = 1; 1842 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) 1843 return failure(); 1844 1845 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { 1846 uint64_t resultIdx = 0; 1847 if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) 1848 return failure(); 1849 1850 uint64_t numValues; 1851 bool indexPairEncoding; 1852 if 
(failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) 1853 return failure(); 1854 1855 SmallVector<unsigned, 4> useListOrders; 1856 for (size_t idx = 0; idx < numValues; idx++) { 1857 uint64_t index; 1858 if (failed(reader.parseVarInt(index))) 1859 return failure(); 1860 useListOrders.push_back(index); 1861 } 1862 1863 // Store in a map the result index 1864 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, 1865 std::move(useListOrders))); 1866 } 1867 1868 return map; 1869 } 1870 1871 /// Sorts each use according to the order specified in the use-list parsed. If 1872 /// the custom use-list is not found, this means that the order needs to be 1873 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie 1874 /// on the same operation, the order will follow the reverse operand number 1875 /// ordering. 1876 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { 1877 // Early return for trivial use-lists. 1878 if (value.use_empty() || value.hasOneUse()) 1879 return success(); 1880 1881 bool hasIncomingOrder = 1882 valueToUseListMap.contains(value.getAsOpaquePointer()); 1883 1884 // Compute the current order of the use-list with respect to the global 1885 // ordering. Detect if the order is already sorted while doing so. 1886 bool alreadySorted = true; 1887 auto &firstUse = *value.use_begin(); 1888 uint64_t prevID = 1889 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); 1890 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; 1891 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1892 uint64_t currentID = bytecode::getUseID( 1893 item.value(), operationIDs.at(item.value().getOwner())); 1894 alreadySorted &= prevID > currentID; 1895 currentOrder.push_back({item.index(), currentID}); 1896 prevID = currentID; 1897 } 1898 1899 // If the order is already sorted, and there wasn't a custom order to apply 1900 // from the bytecode file, we are done. 
1901 if (alreadySorted && !hasIncomingOrder) 1902 return success(); 1903 1904 // If not already sorted, sort the indices of the current order by descending 1905 // useIDs. 1906 if (!alreadySorted) 1907 std::sort( 1908 currentOrder.begin(), currentOrder.end(), 1909 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1910 1911 if (!hasIncomingOrder) { 1912 // If the bytecode file did not contain any custom use-list order, it means 1913 // that the order was descending useID. Hence, shuffle by the first index 1914 // of the `currentOrder` pair. 1915 SmallVector<unsigned> shuffle = SmallVector<unsigned>( 1916 llvm::map_range(currentOrder, [&](auto item) { return item.first; })); 1917 value.shuffleUseList(shuffle); 1918 return success(); 1919 } 1920 1921 // Pull the custom order info from the map. 1922 UseListOrderStorage customOrder = 1923 valueToUseListMap.at(value.getAsOpaquePointer()); 1924 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); 1925 uint64_t numUses = 1926 std::distance(value.getUses().begin(), value.getUses().end()); 1927 1928 // If the encoding was a pair of indices `(src, dst)` for every permutation, 1929 // reconstruct the shuffle vector for every use. Initialize the shuffle vector 1930 // as identity, and then apply the mapping encoded in the indices. 1931 if (customOrder.isIndexPairEncoding) { 1932 // Return failure if the number of indices was not representing pairs. 1933 if (shuffle.size() & 1) 1934 return failure(); 1935 1936 SmallVector<unsigned, 4> newShuffle(numUses); 1937 size_t idx = 0; 1938 std::iota(newShuffle.begin(), newShuffle.end(), idx); 1939 for (idx = 0; idx < shuffle.size(); idx += 2) 1940 newShuffle[shuffle[idx]] = shuffle[idx + 1]; 1941 1942 shuffle = std::move(newShuffle); 1943 } 1944 1945 // Make sure that the indices represent a valid mapping. That is, the sum of 1946 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no 1947 // duplicates are allowed in the list. 
1948 DenseSet<unsigned> set; 1949 uint64_t accumulator = 0; 1950 for (const auto &elem : shuffle) { 1951 if (set.contains(elem)) 1952 return failure(); 1953 accumulator += elem; 1954 set.insert(elem); 1955 } 1956 if (numUses != shuffle.size() || 1957 accumulator != (((numUses - 1) * numUses) >> 1)) 1958 return failure(); 1959 1960 // Apply the current ordering map onto the shuffle vector to get the final 1961 // use-list sorting indices before shuffling. 1962 shuffle = SmallVector<unsigned, 4>(llvm::map_range( 1963 currentOrder, [&](auto item) { return shuffle[item.first]; })); 1964 value.shuffleUseList(shuffle); 1965 return success(); 1966 } 1967 1968 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { 1969 // Precompute operation IDs according to the pre-order walk of the IR. We 1970 // can't do this while parsing since parseRegions ordering is not strictly 1971 // equal to the pre-order walk. 1972 unsigned operationID = 0; 1973 topLevelOp->walk<mlir::WalkOrder::PreOrder>( 1974 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); 1975 1976 auto blockWalk = topLevelOp->walk([this](Block *block) { 1977 for (auto arg : block->getArguments()) 1978 if (failed(sortUseListOrder(arg))) 1979 return WalkResult::interrupt(); 1980 return WalkResult::advance(); 1981 }); 1982 1983 auto resultWalk = topLevelOp->walk([this](Operation *op) { 1984 for (auto result : op->getResults()) 1985 if (failed(sortUseListOrder(result))) 1986 return WalkResult::interrupt(); 1987 return WalkResult::advance(); 1988 }); 1989 1990 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); 1991 } 1992 1993 //===----------------------------------------------------------------------===// 1994 // IR Section 1995 1996 LogicalResult 1997 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, 1998 Block *block) { 1999 EncodingReader reader(sectionData, fileLoc); 2000 2001 // A stack of operation regions currently being read from the 
bytecode. 2002 std::vector<RegionReadState> regionStack; 2003 2004 // Parse the top-level block using a temporary module operation. 2005 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); 2006 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); 2007 regionStack.back().curBlocks.push_back(moduleOp->getBody()); 2008 regionStack.back().curBlock = regionStack.back().curRegion->begin(); 2009 if (failed(parseBlockHeader(reader, regionStack.back()))) 2010 return failure(); 2011 valueScopes.emplace_back(); 2012 valueScopes.back().push(regionStack.back()); 2013 2014 // Iteratively parse regions until everything has been resolved. 2015 while (!regionStack.empty()) 2016 if (failed(parseRegions(regionStack, regionStack.back()))) 2017 return failure(); 2018 if (!forwardRefOps.empty()) { 2019 return reader.emitError( 2020 "not all forward unresolved forward operand references"); 2021 } 2022 2023 // Sort use-lists according to what specified in bytecode. 2024 if (failed(processUseLists(*moduleOp))) 2025 return reader.emitError( 2026 "parsed use-list orders were invalid and could not be applied"); 2027 2028 // Resolve dialect version. 2029 for (const BytecodeDialect &byteCodeDialect : dialects) { 2030 // Parsing is complete, give an opportunity to each dialect to visit the 2031 // IR and perform upgrades. 2032 if (!byteCodeDialect.loadedVersion) 2033 continue; 2034 if (byteCodeDialect.interface && 2035 failed(byteCodeDialect.interface->upgradeFromVersion( 2036 *moduleOp, *byteCodeDialect.loadedVersion))) 2037 return failure(); 2038 } 2039 2040 // Verify that the parsed operations are valid. 2041 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) 2042 return failure(); 2043 2044 // Splice the parsed operations over to the provided top-level block. 
2045 auto &parsedOps = moduleOp->getBody()->getOperations(); 2046 auto &destOps = block->getOperations(); 2047 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); 2048 return success(); 2049 } 2050 2051 LogicalResult 2052 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, 2053 RegionReadState &readState) { 2054 // Process regions, blocks, and operations until the end or if a nested 2055 // region is encountered. In this case we push a new state in regionStack and 2056 // return, the processing of the current region will resume afterward. 2057 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { 2058 // If the current block hasn't been setup yet, parse the header for this 2059 // region. The current block is already setup when this function was 2060 // interrupted to recurse down in a nested region and we resume the current 2061 // block after processing the nested region. 2062 if (readState.curBlock == Region::iterator()) { 2063 if (failed(parseRegion(readState))) 2064 return failure(); 2065 2066 // If the region is empty, there is nothing to more to do. 2067 if (readState.curRegion->empty()) 2068 continue; 2069 } 2070 2071 // Parse the blocks within the region. 2072 EncodingReader &reader = *readState.reader; 2073 do { 2074 while (readState.numOpsRemaining--) { 2075 // Read in the next operation. We don't read its regions directly, we 2076 // handle those afterwards as necessary. 2077 bool isIsolatedFromAbove = false; 2078 FailureOr<Operation *> op = 2079 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); 2080 if (failed(op)) 2081 return failure(); 2082 2083 // If the op has regions, add it to the stack for processing and return: 2084 // we stop the processing of the current region and resume it after the 2085 // inner one is completed. Unless LazyLoading is activated in which case 2086 // nested region parsing is delayed. 
2087 if ((*op)->getNumRegions()) { 2088 RegionReadState childState(*op, &reader, isIsolatedFromAbove); 2089 2090 // Isolated regions are encoded as a section in version 2 and above. 2091 if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { 2092 bytecode::Section::ID sectionID; 2093 ArrayRef<uint8_t> sectionData; 2094 if (failed(reader.parseSection(sectionID, sectionData))) 2095 return failure(); 2096 if (sectionID != bytecode::Section::kIR) 2097 return emitError(fileLoc, "expected IR section for region"); 2098 childState.owningReader = 2099 std::make_unique<EncodingReader>(sectionData, fileLoc); 2100 childState.reader = childState.owningReader.get(); 2101 2102 // If the user has a callback set, they have the opportunity to 2103 // control lazyloading as we go. 2104 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) { 2105 lazyLoadableOps.emplace_back(*op, std::move(childState)); 2106 lazyLoadableOpsMap.try_emplace(*op, 2107 std::prev(lazyLoadableOps.end())); 2108 continue; 2109 } 2110 } 2111 regionStack.push_back(std::move(childState)); 2112 2113 // If the op is isolated from above, push a new value scope. 2114 if (isIsolatedFromAbove) 2115 valueScopes.emplace_back(); 2116 return success(); 2117 } 2118 } 2119 2120 // Move to the next block of the region. 2121 if (++readState.curBlock == readState.curRegion->end()) 2122 break; 2123 if (failed(parseBlockHeader(reader, readState))) 2124 return failure(); 2125 } while (true); 2126 2127 // Reset the current block and any values reserved for this region. 2128 readState.curBlock = {}; 2129 valueScopes.back().pop(readState); 2130 } 2131 2132 // When the regions have been fully parsed, pop them off of the read stack. If 2133 // the regions were isolated from above, we also pop the last value scope. 
2134 if (readState.isIsolatedFromAbove) { 2135 assert(!valueScopes.empty() && "Expect a valueScope after reading region"); 2136 valueScopes.pop_back(); 2137 } 2138 assert(!regionStack.empty() && "Expect a regionStack after reading region"); 2139 regionStack.pop_back(); 2140 return success(); 2141 } 2142 2143 FailureOr<Operation *> 2144 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, 2145 RegionReadState &readState, 2146 bool &isIsolatedFromAbove) { 2147 // Parse the name of the operation. 2148 std::optional<bool> wasRegistered; 2149 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); 2150 if (failed(opName)) 2151 return failure(); 2152 2153 // Parse the operation mask, which indicates which components of the operation 2154 // are present. 2155 uint8_t opMask; 2156 if (failed(reader.parseByte(opMask))) 2157 return failure(); 2158 2159 /// Parse the location. 2160 LocationAttr opLoc; 2161 if (failed(parseAttribute(reader, opLoc))) 2162 return failure(); 2163 2164 // With the location and name resolved, we can start building the operation 2165 // state. 2166 OperationState opState(opLoc, *opName); 2167 2168 // Parse the attributes of the operation. 2169 if (opMask & bytecode::OpEncodingMask::kHasAttrs) { 2170 DictionaryAttr dictAttr; 2171 if (failed(parseAttribute(reader, dictAttr))) 2172 return failure(); 2173 opState.attributes = dictAttr; 2174 } 2175 2176 if (opMask & bytecode::OpEncodingMask::kHasProperties) { 2177 // kHasProperties wasn't emitted in older bytecode, we should never get 2178 // there without also having the `wasRegistered` flag available. 2179 if (!wasRegistered) 2180 return emitError(fileLoc, 2181 "Unexpected missing `wasRegistered` opname flag at " 2182 "bytecode version ") 2183 << version << " with properties."; 2184 // When an operation is emitted without being registered, the properties are 2185 // stored as an attribute. 
Otherwise the op must implement the bytecode 2186 // interface and control the serialization. 2187 if (wasRegistered) { 2188 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 2189 reader); 2190 if (failed( 2191 propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) 2192 return failure(); 2193 } else { 2194 // If the operation wasn't registered when it was emitted, the properties 2195 // was serialized as an attribute. 2196 if (failed(parseAttribute(reader, opState.propertiesAttr))) 2197 return failure(); 2198 } 2199 } 2200 2201 /// Parse the results of the operation. 2202 if (opMask & bytecode::OpEncodingMask::kHasResults) { 2203 uint64_t numResults; 2204 if (failed(reader.parseVarInt(numResults))) 2205 return failure(); 2206 opState.types.resize(numResults); 2207 for (int i = 0, e = numResults; i < e; ++i) 2208 if (failed(parseType(reader, opState.types[i]))) 2209 return failure(); 2210 } 2211 2212 /// Parse the operands of the operation. 2213 if (opMask & bytecode::OpEncodingMask::kHasOperands) { 2214 uint64_t numOperands; 2215 if (failed(reader.parseVarInt(numOperands))) 2216 return failure(); 2217 opState.operands.resize(numOperands); 2218 for (int i = 0, e = numOperands; i < e; ++i) 2219 if (!(opState.operands[i] = parseOperand(reader))) 2220 return failure(); 2221 } 2222 2223 /// Parse the successors of the operation. 2224 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { 2225 uint64_t numSuccs; 2226 if (failed(reader.parseVarInt(numSuccs))) 2227 return failure(); 2228 opState.successors.resize(numSuccs); 2229 for (int i = 0, e = numSuccs; i < e; ++i) { 2230 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], 2231 "successor"))) 2232 return failure(); 2233 } 2234 } 2235 2236 /// Parse the use-list orders for the results of the operation. Use-list 2237 /// orders are available since version 3 of the bytecode. 
2238 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; 2239 if (version >= bytecode::kUseListOrdering && 2240 (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { 2241 size_t numResults = opState.types.size(); 2242 auto parseResult = parseUseListOrderForRange(reader, numResults); 2243 if (failed(parseResult)) 2244 return failure(); 2245 resultIdxToUseListMap = std::move(*parseResult); 2246 } 2247 2248 /// Parse the regions of the operation. 2249 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { 2250 uint64_t numRegions; 2251 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) 2252 return failure(); 2253 2254 opState.regions.reserve(numRegions); 2255 for (int i = 0, e = numRegions; i < e; ++i) 2256 opState.regions.push_back(std::make_unique<Region>()); 2257 } 2258 2259 // Create the operation at the back of the current block. 2260 Operation *op = Operation::create(opState); 2261 readState.curBlock->push_back(op); 2262 2263 // If the operation had results, update the value references. 2264 if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) 2265 return failure(); 2266 2267 /// Store a map for every value that received a custom use-list order from the 2268 /// bytecode file. 2269 if (resultIdxToUseListMap.has_value()) { 2270 for (size_t idx = 0; idx < op->getNumResults(); idx++) { 2271 if (resultIdxToUseListMap->contains(idx)) { 2272 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), 2273 resultIdxToUseListMap->at(idx)); 2274 } 2275 } 2276 } 2277 return op; 2278 } 2279 2280 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { 2281 EncodingReader &reader = *readState.reader; 2282 2283 // Parse the number of blocks in the region. 2284 uint64_t numBlocks; 2285 if (failed(reader.parseVarInt(numBlocks))) 2286 return failure(); 2287 2288 // If the region is empty, there is nothing else to do. 
2289 if (numBlocks == 0) 2290 return success(); 2291 2292 // Parse the number of values defined in this region. 2293 uint64_t numValues; 2294 if (failed(reader.parseVarInt(numValues))) 2295 return failure(); 2296 readState.numValues = numValues; 2297 2298 // Create the blocks within this region. We do this before processing so that 2299 // we can rely on the blocks existing when creating operations. 2300 readState.curBlocks.clear(); 2301 readState.curBlocks.reserve(numBlocks); 2302 for (uint64_t i = 0; i < numBlocks; ++i) { 2303 readState.curBlocks.push_back(new Block()); 2304 readState.curRegion->push_back(readState.curBlocks.back()); 2305 } 2306 2307 // Prepare the current value scope for this region. 2308 valueScopes.back().push(readState); 2309 2310 // Parse the entry block of the region. 2311 readState.curBlock = readState.curRegion->begin(); 2312 return parseBlockHeader(reader, readState); 2313 } 2314 2315 LogicalResult 2316 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, 2317 RegionReadState &readState) { 2318 bool hasArgs; 2319 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) 2320 return failure(); 2321 2322 // Parse the arguments of the block. 2323 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) 2324 return failure(); 2325 2326 // Uselist orders are available since version 3 of the bytecode. 
2327 if (version < bytecode::kUseListOrdering) 2328 return success(); 2329 2330 uint8_t hasUseListOrders = 0; 2331 if (hasArgs && failed(reader.parseByte(hasUseListOrders))) 2332 return failure(); 2333 2334 if (!hasUseListOrders) 2335 return success(); 2336 2337 Block &blk = *readState.curBlock; 2338 auto argIdxToUseListMap = 2339 parseUseListOrderForRange(reader, blk.getNumArguments()); 2340 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) 2341 return failure(); 2342 2343 for (size_t idx = 0; idx < blk.getNumArguments(); idx++) 2344 if (argIdxToUseListMap->contains(idx)) 2345 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), 2346 argIdxToUseListMap->at(idx)); 2347 2348 // We don't parse the operations of the block here, that's done elsewhere. 2349 return success(); 2350 } 2351 2352 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, 2353 Block *block) { 2354 // Parse the value ID for the first argument, and the number of arguments. 2355 uint64_t numArgs; 2356 if (failed(reader.parseVarInt(numArgs))) 2357 return failure(); 2358 2359 SmallVector<Type> argTypes; 2360 SmallVector<Location> argLocs; 2361 argTypes.reserve(numArgs); 2362 argLocs.reserve(numArgs); 2363 2364 Location unknownLoc = UnknownLoc::get(config.getContext()); 2365 while (numArgs--) { 2366 Type argType; 2367 LocationAttr argLoc = unknownLoc; 2368 if (version >= bytecode::kElideUnknownBlockArgLocation) { 2369 // Parse the type with hasLoc flag to determine if it has type. 2370 uint64_t typeIdx; 2371 bool hasLoc; 2372 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || 2373 !(argType = attrTypeReader.resolveType(typeIdx))) 2374 return failure(); 2375 if (hasLoc && failed(parseAttribute(reader, argLoc))) 2376 return failure(); 2377 } else { 2378 // All args has type and location. 
2379 if (failed(parseType(reader, argType)) || 2380 failed(parseAttribute(reader, argLoc))) 2381 return failure(); 2382 } 2383 argTypes.push_back(argType); 2384 argLocs.push_back(argLoc); 2385 } 2386 block->addArguments(argTypes, argLocs); 2387 return defineValues(reader, block->getArguments()); 2388 } 2389 2390 //===----------------------------------------------------------------------===// 2391 // Value Processing 2392 2393 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { 2394 std::vector<Value> &values = valueScopes.back().values; 2395 Value *value = nullptr; 2396 if (failed(parseEntry(reader, values, value, "value"))) 2397 return Value(); 2398 2399 // Create a new forward reference if necessary. 2400 if (!*value) 2401 *value = createForwardRef(); 2402 return *value; 2403 } 2404 2405 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, 2406 ValueRange newValues) { 2407 ValueScope &valueScope = valueScopes.back(); 2408 std::vector<Value> &values = valueScope.values; 2409 2410 unsigned &valueID = valueScope.nextValueIDs.back(); 2411 unsigned valueIDEnd = valueID + newValues.size(); 2412 if (valueIDEnd > values.size()) { 2413 return reader.emitError( 2414 "value index range was outside of the expected range for " 2415 "the parent region, got [", 2416 valueID, ", ", valueIDEnd, "), but the maximum index was ", 2417 values.size() - 1); 2418 } 2419 2420 // Assign the values and update any forward references. 2421 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { 2422 Value newValue = newValues[i]; 2423 2424 // Check to see if a definition for this value already exists. 2425 if (Value oldValue = std::exchange(values[valueID], newValue)) { 2426 Operation *forwardRefOp = oldValue.getDefiningOp(); 2427 2428 // Assert that this is a forward reference operation. Given how we compute 2429 // definition ids (incrementally as we parse), it shouldn't be possible 2430 // for the value to be defined any other way. 
2431 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && 2432 "value index was already defined?"); 2433 2434 oldValue.replaceAllUsesWith(newValue); 2435 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); 2436 } 2437 } 2438 return success(); 2439 } 2440 2441 Value BytecodeReader::Impl::createForwardRef() { 2442 // Check for an avaliable existing operation to use. Otherwise, create a new 2443 // fake operation to use for the reference. 2444 if (!openForwardRefOps.empty()) { 2445 Operation *op = &openForwardRefOps.back(); 2446 op->moveBefore(&forwardRefOps, forwardRefOps.end()); 2447 } else { 2448 forwardRefOps.push_back(Operation::create(forwardRefOpState)); 2449 } 2450 return forwardRefOps.back().getResult(0); 2451 } 2452 2453 //===----------------------------------------------------------------------===// 2454 // Entry Points 2455 //===----------------------------------------------------------------------===// 2456 2457 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } 2458 2459 BytecodeReader::BytecodeReader( 2460 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, 2461 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2462 Location sourceFileLoc = 2463 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2464 /*line=*/0, /*column=*/0); 2465 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, 2466 bufferOwnerRef); 2467 } 2468 2469 LogicalResult BytecodeReader::readTopLevel( 2470 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2471 return impl->read(block, lazyOpsCallback); 2472 } 2473 2474 int64_t BytecodeReader::getNumOpsToMaterialize() const { 2475 return impl->getNumOpsToMaterialize(); 2476 } 2477 2478 bool BytecodeReader::isMaterializable(Operation *op) { 2479 return impl->isMaterializable(op); 2480 } 2481 2482 LogicalResult BytecodeReader::materialize( 2483 Operation *op, 
llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2484 return impl->materialize(op, lazyOpsCallback); 2485 } 2486 2487 LogicalResult 2488 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { 2489 return impl->finalize(shouldMaterialize); 2490 } 2491 2492 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { 2493 return buffer.getBuffer().startswith("ML\xefR"); 2494 } 2495 2496 /// Read the bytecode from the provided memory buffer reference. 2497 /// `bufferOwnerRef` if provided is the owning source manager for the buffer, 2498 /// and may be used to extend the lifetime of the buffer. 2499 static LogicalResult 2500 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, 2501 const ParserConfig &config, 2502 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2503 Location sourceFileLoc = 2504 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2505 /*line=*/0, /*column=*/0); 2506 if (!isBytecode(buffer)) { 2507 return emitError(sourceFileLoc, 2508 "input buffer is not an MLIR bytecode file"); 2509 } 2510 2511 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, 2512 buffer, bufferOwnerRef); 2513 return reader.read(block, /*lazyOpsCallback=*/nullptr); 2514 } 2515 2516 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, 2517 const ParserConfig &config) { 2518 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); 2519 } 2520 LogicalResult 2521 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 2522 Block *block, const ParserConfig &config) { 2523 return readBytecodeFileImpl( 2524 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, 2525 sourceMgr); 2526 } 2527