1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 // TODO: Support for big-endian architectures. 10 11 #include "mlir/Bytecode/BytecodeReader.h" 12 #include "mlir/AsmParser/AsmParser.h" 13 #include "mlir/Bytecode/BytecodeImplementation.h" 14 #include "mlir/Bytecode/Encoding.h" 15 #include "mlir/IR/BuiltinDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/Diagnostics.h" 18 #include "mlir/IR/OpImplementation.h" 19 #include "mlir/IR/Verifier.h" 20 #include "mlir/IR/Visitors.h" 21 #include "mlir/Support/LLVM.h" 22 #include "mlir/Support/LogicalResult.h" 23 #include "llvm/ADT/MapVector.h" 24 #include "llvm/ADT/ScopeExit.h" 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/ADT/StringRef.h" 28 #include "llvm/Support/MemoryBufferRef.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/SourceMgr.h" 31 #include <list> 32 #include <memory> 33 #include <numeric> 34 #include <optional> 35 36 #define DEBUG_TYPE "mlir-bytecode-reader" 37 38 using namespace mlir; 39 40 /// Stringify the given section ID. 41 static std::string toString(bytecode::Section::ID sectionID) { 42 switch (sectionID) { 43 case bytecode::Section::kString: 44 return "String (0)"; 45 case bytecode::Section::kDialect: 46 return "Dialect (1)"; 47 case bytecode::Section::kAttrType: 48 return "AttrType (2)"; 49 case bytecode::Section::kAttrTypeOffset: 50 return "AttrTypeOffset (3)"; 51 case bytecode::Section::kIR: 52 return "IR (4)"; 53 case bytecode::Section::kResource: 54 return "Resource (5)"; 55 case bytecode::Section::kResourceOffset: 56 return "ResourceOffset (6)"; 57 case bytecode::Section::kDialectVersions: 58 return "DialectVersions (7)"; 59 default: 60 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); 61 } 62 } 63 64 /// Returns true if the given top-level section ID is optional. 65 static bool isSectionOptional(bytecode::Section::ID sectionID) { 66 switch (sectionID) { 67 case bytecode::Section::kString: 68 case bytecode::Section::kDialect: 69 case bytecode::Section::kAttrType: 70 case bytecode::Section::kAttrTypeOffset: 71 case bytecode::Section::kIR: 72 return false; 73 case bytecode::Section::kResource: 74 case bytecode::Section::kResourceOffset: 75 case bytecode::Section::kDialectVersions: 76 return true; 77 default: 78 llvm_unreachable("unknown section ID"); 79 } 80 } 81 82 //===----------------------------------------------------------------------===// 83 // EncodingReader 84 //===----------------------------------------------------------------------===// 85 86 namespace { 87 class EncodingReader { 88 public: 89 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) 90 : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} 91 explicit EncodingReader(StringRef contents, Location fileLoc) 92 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), 93 contents.size()}, 94 fileLoc) {} 95 96 /// Returns true if the entire section has been read. 97 bool empty() const { return dataIt == dataEnd; } 98 99 /// Returns the remaining size of the bytecode. 100 size_t size() const { return dataEnd - dataIt; } 101 102 /// Align the current reader position to the specified alignment. 103 LogicalResult alignTo(unsigned alignment) { 104 if (!llvm::isPowerOf2_32(alignment)) 105 return emitError("expected alignment to be a power-of-two"); 106 107 // Shift the reader position to the next alignment boundary. 108 while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) { 109 uint8_t padding; 110 if (failed(parseByte(padding))) 111 return failure(); 112 if (padding != bytecode::kAlignmentByte) { 113 return emitError("expected alignment byte (0xCB), but got: '0x" + 114 llvm::utohexstr(padding) + "'"); 115 } 116 } 117 118 // Ensure the data iterator is now aligned. This case is unlikely because we 119 // *just* went through the effort to align the data iterator. 120 if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) { 121 return emitError("expected data iterator aligned to ", alignment, 122 ", but got pointer: '0x" + 123 llvm::utohexstr((uintptr_t)dataIt) + "'"); 124 } 125 126 return success(); 127 } 128 129 /// Emit an error using the given arguments. 130 template <typename... Args> 131 InFlightDiagnostic emitError(Args &&...args) const { 132 return ::emitError(fileLoc).append(std::forward<Args>(args)...); 133 } 134 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } 135 136 /// Parse a single byte from the stream. 137 template <typename T> 138 LogicalResult parseByte(T &value) { 139 if (empty()) 140 return emitError("attempting to parse a byte at the end of the bytecode"); 141 value = static_cast<T>(*dataIt++); 142 return success(); 143 } 144 /// Parse a range of bytes of 'length' into the given result. 145 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { 146 if (length > size()) { 147 return emitError("attempting to parse ", length, " bytes when only ", 148 size(), " remain"); 149 } 150 result = {dataIt, length}; 151 dataIt += length; 152 return success(); 153 } 154 /// Parse a range of bytes of 'length' into the given result, which can be 155 /// assumed to be large enough to hold `length`. 156 LogicalResult parseBytes(size_t length, uint8_t *result) { 157 if (length > size()) { 158 return emitError("attempting to parse ", length, " bytes when only ", 159 size(), " remain"); 160 } 161 memcpy(result, dataIt, length); 162 dataIt += length; 163 return success(); 164 } 165 166 /// Parse an aligned blob of data, where the alignment was encoded alongside 167 /// the data. 168 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, 169 uint64_t &alignment) { 170 uint64_t dataSize; 171 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || 172 failed(alignTo(alignment))) 173 return failure(); 174 return parseBytes(dataSize, data); 175 } 176 177 /// Parse a variable length encoded integer from the byte stream. The first 178 /// encoded byte contains a prefix in the low bits indicating the encoded 179 /// length of the value. This length prefix is a bit sequence of '0's followed 180 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes 181 /// (not including the prefix byte). All remaining bits in the first byte, 182 /// along with all of the bits in additional bytes, provide the value of the 183 /// integer encoded in little-endian order. 184 LogicalResult parseVarInt(uint64_t &result) { 185 // Parse the first byte of the encoding, which contains the length prefix. 186 if (failed(parseByte(result))) 187 return failure(); 188 189 // Handle the overwhelmingly common case where the value is stored in a 190 // single byte. In this case, the first bit is the `1` marker bit. 191 if (LLVM_LIKELY(result & 1)) { 192 result >>= 1; 193 return success(); 194 } 195 196 // Handle the overwhelming uncommon case where the value required all 8 197 // bytes (i.e. a really really big number). In this case, the marker byte is 198 // all zeros: `00000000`. 199 if (LLVM_UNLIKELY(result == 0)) 200 return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result)); 201 return parseMultiByteVarInt(result); 202 } 203 204 /// Parse a signed variable length encoded integer from the byte stream. A 205 /// signed varint is encoded as a normal varint with zigzag encoding applied, 206 /// i.e. the low bit of the value is used to indicate the sign. 207 LogicalResult parseSignedVarInt(uint64_t &result) { 208 if (failed(parseVarInt(result))) 209 return failure(); 210 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) 211 result = (result >> 1) ^ (~(result & 1) + 1); 212 return success(); 213 } 214 215 /// Parse a variable length encoded integer whose low bit is used to encode an 216 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. 217 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { 218 if (failed(parseVarInt(result))) 219 return failure(); 220 flag = result & 1; 221 result >>= 1; 222 return success(); 223 } 224 225 /// Skip the first `length` bytes within the reader. 226 LogicalResult skipBytes(size_t length) { 227 if (length > size()) { 228 return emitError("attempting to skip ", length, " bytes when only ", 229 size(), " remain"); 230 } 231 dataIt += length; 232 return success(); 233 } 234 235 /// Parse a null-terminated string into `result` (without including the NUL 236 /// terminator). 237 LogicalResult parseNullTerminatedString(StringRef &result) { 238 const char *startIt = (const char *)dataIt; 239 const char *nulIt = (const char *)memchr(startIt, 0, size()); 240 if (!nulIt) 241 return emitError( 242 "malformed null-terminated string, no null character found"); 243 244 result = StringRef(startIt, nulIt - startIt); 245 dataIt = (const uint8_t *)nulIt + 1; 246 return success(); 247 } 248 249 /// Parse a section header, placing the kind of section in `sectionID` and the 250 /// contents of the section in `sectionData`. 251 LogicalResult parseSection(bytecode::Section::ID §ionID, 252 ArrayRef<uint8_t> §ionData) { 253 uint8_t sectionIDAndHasAlignment; 254 uint64_t length; 255 if (failed(parseByte(sectionIDAndHasAlignment)) || 256 failed(parseVarInt(length))) 257 return failure(); 258 259 // Extract the section ID and whether the section is aligned. The high bit 260 // of the ID is the alignment flag. 261 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & 262 0b01111111); 263 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; 264 265 // Check that the section is actually valid before trying to process its 266 // data. 267 if (sectionID >= bytecode::Section::kNumSections) 268 return emitError("invalid section ID: ", unsigned(sectionID)); 269 270 // Process the section alignment if present. 271 if (hasAlignment) { 272 uint64_t alignment; 273 if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) 274 return failure(); 275 } 276 277 // Parse the actual section data. 278 return parseBytes(static_cast<size_t>(length), sectionData); 279 } 280 281 Location getLoc() const { return fileLoc; } 282 283 private: 284 /// Parse a variable length encoded integer from the byte stream. This method 285 /// is a fallback when the number of bytes used to encode the value is greater 286 /// than 1, but less than the max (9). The provided `result` value can be 287 /// assumed to already contain the first byte of the value. 288 /// NOTE: This method is marked noinline to avoid pessimizing the common case 289 /// of single byte encoding. 290 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) { 291 // Count the number of trailing zeros in the marker byte, this indicates the 292 // number of trailing bytes that are part of the value. We use `uint32_t` 293 // here because we only care about the first byte, and so that be actually 294 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop 295 // implementation). 296 uint32_t numBytes = llvm::countr_zero<uint32_t>(result); 297 assert(numBytes > 0 && numBytes <= 7 && 298 "unexpected number of trailing zeros in varint encoding"); 299 300 // Parse in the remaining bytes of the value. 301 if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1))) 302 return failure(); 303 304 // Shift out the low-order bits that were used to mark how the value was 305 // encoded. 306 result >>= (numBytes + 1); 307 return success(); 308 } 309 310 /// The current data iterator, and an iterator to the end of the buffer. 311 const uint8_t *dataIt, *dataEnd; 312 313 /// A location for the bytecode used to report errors. 314 Location fileLoc; 315 }; 316 } // namespace 317 318 /// Resolve an index into the given entry list. `entry` may either be a 319 /// reference, in which case it is assigned to the corresponding value in 320 /// `entries`, or a pointer, in which case it is assigned to the address of the 321 /// element in `entries`. 322 template <typename RangeT, typename T> 323 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, 324 uint64_t index, T &entry, 325 StringRef entryStr) { 326 if (index >= entries.size()) 327 return reader.emitError("invalid ", entryStr, " index: ", index); 328 329 // If the provided entry is a pointer, resolve to the address of the entry. 330 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) 331 entry = entries[index]; 332 else 333 entry = &entries[index]; 334 return success(); 335 } 336 337 /// Parse and resolve an index into the given entry list. 338 template <typename RangeT, typename T> 339 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, 340 T &entry, StringRef entryStr) { 341 uint64_t entryIdx; 342 if (failed(reader.parseVarInt(entryIdx))) 343 return failure(); 344 return resolveEntry(reader, entries, entryIdx, entry, entryStr); 345 } 346 347 //===----------------------------------------------------------------------===// 348 // StringSectionReader 349 //===----------------------------------------------------------------------===// 350 351 namespace { 352 /// This class is used to read references to the string section from the 353 /// bytecode. 354 class StringSectionReader { 355 public: 356 /// Initialize the string section reader with the given section data. 357 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); 358 359 /// Parse a shared string from the string section. The shared string is 360 /// encoded using an index to a corresponding string in the string section. 361 LogicalResult parseString(EncodingReader &reader, StringRef &result) { 362 return parseEntry(reader, strings, result, "string"); 363 } 364 365 /// Parse a shared string from the string section. The shared string is 366 /// encoded using an index to a corresponding string in the string section. 367 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, 368 StringRef &result) { 369 return resolveEntry(reader, strings, index, result, "string"); 370 } 371 372 private: 373 /// The table of strings referenced within the bytecode file. 374 SmallVector<StringRef> strings; 375 }; 376 } // namespace 377 378 LogicalResult StringSectionReader::initialize(Location fileLoc, 379 ArrayRef<uint8_t> sectionData) { 380 EncodingReader stringReader(sectionData, fileLoc); 381 382 // Parse the number of strings in the section. 383 uint64_t numStrings; 384 if (failed(stringReader.parseVarInt(numStrings))) 385 return failure(); 386 strings.resize(numStrings); 387 388 // Parse each of the strings. The sizes of the strings are encoded in reverse 389 // order, so that's the order we populate the table. 390 size_t stringDataEndOffset = sectionData.size(); 391 for (StringRef &string : llvm::reverse(strings)) { 392 uint64_t stringSize; 393 if (failed(stringReader.parseVarInt(stringSize))) 394 return failure(); 395 if (stringDataEndOffset < stringSize) { 396 return stringReader.emitError( 397 "string size exceeds the available data size"); 398 } 399 400 // Extract the string from the data, dropping the null character. 401 size_t stringOffset = stringDataEndOffset - stringSize; 402 string = StringRef( 403 reinterpret_cast<const char *>(sectionData.data() + stringOffset), 404 stringSize - 1); 405 stringDataEndOffset = stringOffset; 406 } 407 408 // Check that the only remaining data was for the strings, i.e. the reader 409 // should be at the same offset as the first string. 410 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { 411 return stringReader.emitError("unexpected trailing data between the " 412 "offsets for strings and their data"); 413 } 414 return success(); 415 } 416 417 //===----------------------------------------------------------------------===// 418 // BytecodeDialect 419 //===----------------------------------------------------------------------===// 420 421 namespace { 422 class DialectReader; 423 424 /// This struct represents a dialect entry within the bytecode. 425 struct BytecodeDialect { 426 /// Load the dialect into the provided context if it hasn't been loaded yet. 427 /// Returns failure if the dialect couldn't be loaded *and* the provided 428 /// context does not allow unregistered dialects. The provided reader is used 429 /// for error emission if necessary. 430 LogicalResult load(DialectReader &reader, MLIRContext *ctx); 431 432 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can 433 /// only be called after `load`. 434 Dialect *getLoadedDialect() const { 435 assert(dialect && 436 "expected `load` to be invoked before `getLoadedDialect`"); 437 return *dialect; 438 } 439 440 /// The loaded dialect entry. This field is std::nullopt if we haven't 441 /// attempted to load, nullptr if we failed to load, otherwise the loaded 442 /// dialect. 443 std::optional<Dialect *> dialect; 444 445 /// The bytecode interface of the dialect, or nullptr if the dialect does not 446 /// implement the bytecode interface. This field should only be checked if the 447 /// `dialect` field is not std::nullopt. 448 const BytecodeDialectInterface *interface = nullptr; 449 450 /// The name of the dialect. 451 StringRef name; 452 453 /// A buffer containing the encoding of the dialect version parsed. 454 ArrayRef<uint8_t> versionBuffer; 455 456 /// Lazy loaded dialect version from the handle above. 457 std::unique_ptr<DialectVersion> loadedVersion; 458 }; 459 460 /// This struct represents an operation name entry within the bytecode. 461 struct BytecodeOperationName { 462 BytecodeOperationName(BytecodeDialect *dialect, StringRef name) 463 : dialect(dialect), name(name) {} 464 465 /// The loaded operation name, or std::nullopt if it hasn't been processed 466 /// yet. 467 std::optional<OperationName> opName; 468 469 /// The dialect that owns this operation name. 470 BytecodeDialect *dialect; 471 472 /// The name of the operation, without the dialect prefix. 473 StringRef name; 474 }; 475 } // namespace 476 477 /// Parse a single dialect group encoded in the byte stream. 478 static LogicalResult parseDialectGrouping( 479 EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, 480 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { 481 // Parse the dialect and the number of entries in the group. 482 BytecodeDialect *dialect; 483 if (failed(parseEntry(reader, dialects, dialect, "dialect"))) 484 return failure(); 485 uint64_t numEntries; 486 if (failed(reader.parseVarInt(numEntries))) 487 return failure(); 488 489 for (uint64_t i = 0; i < numEntries; ++i) 490 if (failed(entryCallback(dialect))) 491 return failure(); 492 return success(); 493 } 494 495 //===----------------------------------------------------------------------===// 496 // ResourceSectionReader 497 //===----------------------------------------------------------------------===// 498 499 namespace { 500 /// This class is used to read the resource section from the bytecode. 501 class ResourceSectionReader { 502 public: 503 /// Initialize the resource section reader with the given section data. 504 LogicalResult 505 initialize(Location fileLoc, const ParserConfig &config, 506 MutableArrayRef<BytecodeDialect> dialects, 507 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 508 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 509 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); 510 511 /// Parse a dialect resource handle from the resource section. 512 LogicalResult parseResourceHandle(EncodingReader &reader, 513 AsmDialectResourceHandle &result) { 514 return parseEntry(reader, dialectResources, result, "resource handle"); 515 } 516 517 private: 518 /// The table of dialect resources within the bytecode file. 519 SmallVector<AsmDialectResourceHandle> dialectResources; 520 llvm::StringMap<std::string> dialectResourceHandleRenamingMap; 521 }; 522 523 class ParsedResourceEntry : public AsmParsedResourceEntry { 524 public: 525 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, 526 EncodingReader &reader, StringSectionReader &stringReader, 527 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 528 : key(key), kind(kind), reader(reader), stringReader(stringReader), 529 bufferOwnerRef(bufferOwnerRef) {} 530 ~ParsedResourceEntry() override = default; 531 532 StringRef getKey() const final { return key; } 533 534 InFlightDiagnostic emitError() const final { return reader.emitError(); } 535 536 AsmResourceEntryKind getKind() const final { return kind; } 537 538 FailureOr<bool> parseAsBool() const final { 539 if (kind != AsmResourceEntryKind::Bool) 540 return emitError() << "expected a bool resource entry, but found a " 541 << toString(kind) << " entry instead"; 542 543 bool value; 544 if (failed(reader.parseByte(value))) 545 return failure(); 546 return value; 547 } 548 FailureOr<std::string> parseAsString() const final { 549 if (kind != AsmResourceEntryKind::String) 550 return emitError() << "expected a string resource entry, but found a " 551 << toString(kind) << " entry instead"; 552 553 StringRef string; 554 if (failed(stringReader.parseString(reader, string))) 555 return failure(); 556 return string.str(); 557 } 558 559 FailureOr<AsmResourceBlob> 560 parseAsBlob(BlobAllocatorFn allocator) const final { 561 if (kind != AsmResourceEntryKind::Blob) 562 return emitError() << "expected a blob resource entry, but found a " 563 << toString(kind) << " entry instead"; 564 565 ArrayRef<uint8_t> data; 566 uint64_t alignment; 567 if (failed(reader.parseBlobAndAlignment(data, alignment))) 568 return failure(); 569 570 // If we have an extendable reference to the buffer owner, we don't need to 571 // allocate a new buffer for the data, and can use the data directly. 572 if (bufferOwnerRef) { 573 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), 574 data.size()); 575 576 // Allocate an unmanager buffer which captures a reference to the owner. 577 // For now we just mark this as immutable, but in the future we should 578 // explore marking this as mutable when desired. 579 return UnmanagedAsmResourceBlob::allocateWithAlign( 580 charData, alignment, 581 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); 582 } 583 584 // Allocate memory for the blob using the provided allocator and copy the 585 // data into it. 586 AsmResourceBlob blob = allocator(data.size(), alignment); 587 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && 588 blob.isMutable() && 589 "blob allocator did not return a properly aligned address"); 590 memcpy(blob.getMutableData().data(), data.data(), data.size()); 591 return blob; 592 } 593 594 private: 595 StringRef key; 596 AsmResourceEntryKind kind; 597 EncodingReader &reader; 598 StringSectionReader &stringReader; 599 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 600 }; 601 } // namespace 602 603 template <typename T> 604 static LogicalResult 605 parseResourceGroup(Location fileLoc, bool allowEmpty, 606 EncodingReader &offsetReader, EncodingReader &resourceReader, 607 StringSectionReader &stringReader, T *handler, 608 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, 609 function_ref<StringRef(StringRef)> remapKey = {}, 610 function_ref<LogicalResult(StringRef)> processKeyFn = {}) { 611 uint64_t numResources; 612 if (failed(offsetReader.parseVarInt(numResources))) 613 return failure(); 614 615 for (uint64_t i = 0; i < numResources; ++i) { 616 StringRef key; 617 AsmResourceEntryKind kind; 618 uint64_t resourceOffset; 619 ArrayRef<uint8_t> data; 620 if (failed(stringReader.parseString(offsetReader, key)) || 621 failed(offsetReader.parseVarInt(resourceOffset)) || 622 failed(offsetReader.parseByte(kind)) || 623 failed(resourceReader.parseBytes(resourceOffset, data))) 624 return failure(); 625 626 // Process the resource key. 627 if ((processKeyFn && failed(processKeyFn(key)))) 628 return failure(); 629 630 // If the resource data is empty and we allow it, don't error out when 631 // parsing below, just skip it. 632 if (allowEmpty && data.empty()) 633 continue; 634 635 // Ignore the entry if we don't have a valid handler. 636 if (!handler) 637 continue; 638 639 // Otherwise, parse the resource value. 640 EncodingReader entryReader(data, fileLoc); 641 key = remapKey(key); 642 ParsedResourceEntry entry(key, kind, entryReader, stringReader, 643 bufferOwnerRef); 644 if (failed(handler->parseResource(entry))) 645 return failure(); 646 if (!entryReader.empty()) { 647 return entryReader.emitError( 648 "unexpected trailing bytes in resource entry '", key, "'"); 649 } 650 } 651 return success(); 652 } 653 654 LogicalResult ResourceSectionReader::initialize( 655 Location fileLoc, const ParserConfig &config, 656 MutableArrayRef<BytecodeDialect> dialects, 657 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 658 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 659 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 660 EncodingReader resourceReader(sectionData, fileLoc); 661 EncodingReader offsetReader(offsetSectionData, fileLoc); 662 663 // Read the number of external resource providers. 664 uint64_t numExternalResourceGroups; 665 if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) 666 return failure(); 667 668 // Utility functor that dispatches to `parseResourceGroup`, but implicitly 669 // provides most of the arguments. 670 auto parseGroup = [&](auto *handler, bool allowEmpty = false, 671 function_ref<LogicalResult(StringRef)> keyFn = {}) { 672 auto resolveKey = [&](StringRef key) -> StringRef { 673 auto it = dialectResourceHandleRenamingMap.find(key); 674 if (it == dialectResourceHandleRenamingMap.end()) 675 return ""; 676 return it->second; 677 }; 678 679 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, 680 stringReader, handler, bufferOwnerRef, resolveKey, 681 keyFn); 682 }; 683 684 // Read the external resources from the bytecode. 685 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { 686 StringRef key; 687 if (failed(stringReader.parseString(offsetReader, key))) 688 return failure(); 689 690 // Get the handler for these resources. 691 // TODO: Should we require handling external resources in some scenarios? 692 AsmResourceParser *handler = config.getResourceParser(key); 693 if (!handler) { 694 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key 695 << "'"; 696 } 697 698 if (failed(parseGroup(handler))) 699 return failure(); 700 } 701 702 // Read the dialect resources from the bytecode. 703 MLIRContext *ctx = fileLoc->getContext(); 704 while (!offsetReader.empty()) { 705 BytecodeDialect *dialect; 706 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || 707 failed(dialect->load(dialectReader, ctx))) 708 return failure(); 709 Dialect *loadedDialect = dialect->getLoadedDialect(); 710 if (!loadedDialect) { 711 return resourceReader.emitError() 712 << "dialect '" << dialect->name << "' is unknown"; 713 } 714 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); 715 if (!handler) { 716 return resourceReader.emitError() 717 << "unexpected resources for dialect '" << dialect->name << "'"; 718 } 719 720 // Ensure that each resource is declared before being processed. 721 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { 722 FailureOr<AsmDialectResourceHandle> handle = 723 handler->declareResource(key); 724 if (failed(handle)) { 725 return resourceReader.emitError() 726 << "unknown 'resource' key '" << key << "' for dialect '" 727 << dialect->name << "'"; 728 } 729 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); 730 dialectResources.push_back(*handle); 731 return success(); 732 }; 733 734 // Parse the resources for this dialect. We allow empty resources because we 735 // just treat these as declarations. 736 if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) 737 return failure(); 738 } 739 740 return success(); 741 } 742 743 //===----------------------------------------------------------------------===// 744 // Attribute/Type Reader 745 //===----------------------------------------------------------------------===// 746 747 namespace { 748 /// This class provides support for reading attribute and type entries from the 749 /// bytecode. Attribute and Type entries are read lazily on demand, so we use 750 /// this reader to manage when to actually parse them from the bytecode. 751 class AttrTypeReader { 752 /// This class represents a single attribute or type entry. 753 template <typename T> 754 struct Entry { 755 /// The entry, or null if it hasn't been resolved yet. 756 T entry = {}; 757 /// The parent dialect of this entry. 758 BytecodeDialect *dialect = nullptr; 759 /// A flag indicating if the entry was encoded using a custom encoding, 760 /// instead of using the textual assembly format. 761 bool hasCustomEncoding = false; 762 /// The raw data of this entry in the bytecode. 763 ArrayRef<uint8_t> data; 764 }; 765 using AttrEntry = Entry<Attribute>; 766 using TypeEntry = Entry<Type>; 767 768 public: 769 AttrTypeReader(StringSectionReader &stringReader, 770 ResourceSectionReader &resourceReader, Location fileLoc) 771 : stringReader(stringReader), resourceReader(resourceReader), 772 fileLoc(fileLoc) {} 773 774 /// Initialize the attribute and type information within the reader. 775 LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, 776 ArrayRef<uint8_t> sectionData, 777 ArrayRef<uint8_t> offsetSectionData); 778 779 /// Resolve the attribute or type at the given index. Returns nullptr on 780 /// failure. 781 Attribute resolveAttribute(size_t index) { 782 return resolveEntry(attributes, index, "Attribute"); 783 } 784 Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } 785 786 /// Parse a reference to an attribute or type using the given reader. 787 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { 788 uint64_t attrIdx; 789 if (failed(reader.parseVarInt(attrIdx))) 790 return failure(); 791 result = resolveAttribute(attrIdx); 792 return success(!!result); 793 } 794 LogicalResult parseType(EncodingReader &reader, Type &result) { 795 uint64_t typeIdx; 796 if (failed(reader.parseVarInt(typeIdx))) 797 return failure(); 798 result = resolveType(typeIdx); 799 return success(!!result); 800 } 801 802 template <typename T> 803 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 804 Attribute baseResult; 805 if (failed(parseAttribute(reader, baseResult))) 806 return failure(); 807 if ((result = dyn_cast<T>(baseResult))) 808 return success(); 809 return reader.emitError("expected attribute of type: ", 810 llvm::getTypeName<T>(), ", but got: ", baseResult); 811 } 812 813 private: 814 /// Resolve the given entry at `index`. 815 template <typename T> 816 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 817 StringRef entryType); 818 819 /// Parse an entry using the given reader that was encoded using the textual 820 /// assembly format. 821 template <typename T> 822 LogicalResult parseAsmEntry(T &result, EncodingReader &reader, 823 StringRef entryType); 824 825 /// Parse an entry using the given reader that was encoded using a custom 826 /// bytecode format. 827 template <typename T> 828 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, 829 StringRef entryType); 830 831 /// The string section reader used to resolve string references when parsing 832 /// custom encoded attribute/type entries. 833 StringSectionReader &stringReader; 834 835 /// The resource section reader used to resolve resource references when 836 /// parsing custom encoded attribute/type entries. 837 ResourceSectionReader &resourceReader; 838 839 /// The set of attribute and type entries. 840 SmallVector<AttrEntry> attributes; 841 SmallVector<TypeEntry> types; 842 843 /// A location used for error emission. 844 Location fileLoc; 845 }; 846 847 class DialectReader : public DialectBytecodeReader { 848 public: 849 DialectReader(AttrTypeReader &attrTypeReader, 850 StringSectionReader &stringReader, 851 ResourceSectionReader &resourceReader, EncodingReader &reader) 852 : attrTypeReader(attrTypeReader), stringReader(stringReader), 853 resourceReader(resourceReader), reader(reader) {} 854 855 InFlightDiagnostic emitError(const Twine &msg) override { 856 return reader.emitError(msg); 857 } 858 859 DialectReader withEncodingReader(EncodingReader &encReader) { 860 return DialectReader(attrTypeReader, stringReader, resourceReader, 861 encReader); 862 } 863 864 Location getLoc() const { return reader.getLoc(); } 865 866 //===--------------------------------------------------------------------===// 867 // IR 868 //===--------------------------------------------------------------------===// 869 870 LogicalResult readAttribute(Attribute &result) override { 871 return attrTypeReader.parseAttribute(reader, result); 872 } 873 874 LogicalResult readType(Type &result) override { 875 return attrTypeReader.parseType(reader, result); 876 } 877 878 FailureOr<AsmDialectResourceHandle> readResourceHandle() override { 879 AsmDialectResourceHandle handle; 880 if (failed(resourceReader.parseResourceHandle(reader, handle))) 881 return failure(); 882 return handle; 883 } 884 885 //===--------------------------------------------------------------------===// 886 // Primitives 887 //===--------------------------------------------------------------------===// 888 889 LogicalResult readVarInt(uint64_t &result) override { 890 return reader.parseVarInt(result); 891 } 892 893 LogicalResult readSignedVarInt(int64_t &result) override { 894 uint64_t unsignedResult; 895 if (failed(reader.parseSignedVarInt(unsignedResult))) 896 return failure(); 897 result = static_cast<int64_t>(unsignedResult); 898 return success(); 899 } 900 901 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { 902 // Small values are encoded using a single byte. 903 if (bitWidth <= 8) { 904 uint8_t value; 905 if (failed(reader.parseByte(value))) 906 return failure(); 907 return APInt(bitWidth, value); 908 } 909 910 // Large values up to 64 bits are encoded using a single varint. 911 if (bitWidth <= 64) { 912 uint64_t value; 913 if (failed(reader.parseSignedVarInt(value))) 914 return failure(); 915 return APInt(bitWidth, value); 916 } 917 918 // Otherwise, for really big values we encode the array of active words in 919 // the value. 920 uint64_t numActiveWords; 921 if (failed(reader.parseVarInt(numActiveWords))) 922 return failure(); 923 SmallVector<uint64_t, 4> words(numActiveWords); 924 for (uint64_t i = 0; i < numActiveWords; ++i) 925 if (failed(reader.parseSignedVarInt(words[i]))) 926 return failure(); 927 return APInt(bitWidth, words); 928 } 929 930 FailureOr<APFloat> 931 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { 932 FailureOr<APInt> intVal = 933 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); 934 if (failed(intVal)) 935 return failure(); 936 return APFloat(semantics, *intVal); 937 } 938 939 LogicalResult readString(StringRef &result) override { 940 return stringReader.parseString(reader, result); 941 } 942 943 LogicalResult readBlob(ArrayRef<char> &result) override { 944 uint64_t dataSize; 945 ArrayRef<uint8_t> data; 946 if (failed(reader.parseVarInt(dataSize)) || 947 failed(reader.parseBytes(dataSize, data))) 948 return failure(); 949 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), 950 data.size()); 951 return success(); 952 } 953 954 private: 955 AttrTypeReader &attrTypeReader; 956 StringSectionReader &stringReader; 957 ResourceSectionReader &resourceReader; 958 EncodingReader &reader; 959 }; 960 } // namespace 961 962 LogicalResult 963 AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, 964 ArrayRef<uint8_t> sectionData, 965 ArrayRef<uint8_t> offsetSectionData) { 966 EncodingReader offsetReader(offsetSectionData, fileLoc); 967 968 // Parse the number of attribute and type entries. 969 uint64_t numAttributes, numTypes; 970 if (failed(offsetReader.parseVarInt(numAttributes)) || 971 failed(offsetReader.parseVarInt(numTypes))) 972 return failure(); 973 attributes.resize(numAttributes); 974 types.resize(numTypes); 975 976 // A functor used to accumulate the offsets for the entries in the given 977 // range. 978 uint64_t currentOffset = 0; 979 auto parseEntries = [&](auto &&range) { 980 size_t currentIndex = 0, endIndex = range.size(); 981 982 // Parse an individual entry. 983 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { 984 auto &entry = range[currentIndex++]; 985 986 uint64_t entrySize; 987 if (failed(offsetReader.parseVarIntWithFlag(entrySize, 988 entry.hasCustomEncoding))) 989 return failure(); 990 991 // Verify that the offset is actually valid. 992 if (currentOffset + entrySize > sectionData.size()) { 993 return offsetReader.emitError( 994 "Attribute or Type entry offset points past the end of section"); 995 } 996 997 entry.data = sectionData.slice(currentOffset, entrySize); 998 entry.dialect = dialect; 999 currentOffset += entrySize; 1000 return success(); 1001 }; 1002 while (currentIndex != endIndex) 1003 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) 1004 return failure(); 1005 return success(); 1006 }; 1007 1008 // Process each of the attributes, and then the types. 1009 if (failed(parseEntries(attributes)) || failed(parseEntries(types))) 1010 return failure(); 1011 1012 // Ensure that we read everything from the section. 1013 if (!offsetReader.empty()) { 1014 return offsetReader.emitError( 1015 "unexpected trailing data in the Attribute/Type offset section"); 1016 } 1017 return success(); 1018 } 1019 1020 template <typename T> 1021 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 1022 StringRef entryType) { 1023 if (index >= entries.size()) { 1024 emitError(fileLoc) << "invalid " << entryType << " index: " << index; 1025 return {}; 1026 } 1027 1028 // If the entry has already been resolved, there is nothing left to do. 1029 Entry<T> &entry = entries[index]; 1030 if (entry.entry) 1031 return entry.entry; 1032 1033 // Parse the entry. 1034 EncodingReader reader(entry.data, fileLoc); 1035 1036 // Parse based on how the entry was encoded. 1037 if (entry.hasCustomEncoding) { 1038 if (failed(parseCustomEntry(entry, reader, entryType))) 1039 return T(); 1040 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { 1041 return T(); 1042 } 1043 1044 if (!reader.empty()) { 1045 reader.emitError("unexpected trailing bytes after " + entryType + " entry"); 1046 return T(); 1047 } 1048 return entry.entry; 1049 } 1050 1051 template <typename T> 1052 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, 1053 StringRef entryType) { 1054 StringRef asmStr; 1055 if (failed(reader.parseNullTerminatedString(asmStr))) 1056 return failure(); 1057 1058 // Invoke the MLIR assembly parser to parse the entry text. 1059 size_t numRead = 0; 1060 MLIRContext *context = fileLoc->getContext(); 1061 if constexpr (std::is_same_v<T, Type>) 1062 result = 1063 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); 1064 else 1065 result = ::parseAttribute(asmStr, context, Type(), &numRead, 1066 /*isKnownNullTerminated=*/true); 1067 if (!result) 1068 return failure(); 1069 1070 // Ensure there weren't dangling characters after the entry. 1071 if (numRead != asmStr.size()) { 1072 return reader.emitError("trailing characters found after ", entryType, 1073 " assembly format: ", asmStr.drop_front(numRead)); 1074 } 1075 return success(); 1076 } 1077 1078 template <typename T> 1079 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, 1080 EncodingReader &reader, 1081 StringRef entryType) { 1082 DialectReader dialectReader(*this, stringReader, resourceReader, reader); 1083 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) 1084 return failure(); 1085 // Ensure that the dialect implements the bytecode interface. 1086 if (!entry.dialect->interface) { 1087 return reader.emitError("dialect '", entry.dialect->name, 1088 "' does not implement the bytecode interface"); 1089 } 1090 1091 // Ask the dialect to parse the entry. If the dialect is versioned, parse 1092 // using the versioned encoding readers. 1093 if (entry.dialect->loadedVersion.get()) { 1094 if constexpr (std::is_same_v<T, Type>) 1095 entry.entry = entry.dialect->interface->readType( 1096 dialectReader, *entry.dialect->loadedVersion); 1097 else 1098 entry.entry = entry.dialect->interface->readAttribute( 1099 dialectReader, *entry.dialect->loadedVersion); 1100 1101 } else { 1102 if constexpr (std::is_same_v<T, Type>) 1103 entry.entry = entry.dialect->interface->readType(dialectReader); 1104 else 1105 entry.entry = entry.dialect->interface->readAttribute(dialectReader); 1106 } 1107 return success(!!entry.entry); 1108 } 1109 1110 //===----------------------------------------------------------------------===// 1111 // Bytecode Reader 1112 //===----------------------------------------------------------------------===// 1113 1114 /// This class is used to read a bytecode buffer and translate it into MLIR. 1115 class mlir::BytecodeReader::Impl { 1116 struct RegionReadState; 1117 using LazyLoadableOpsInfo = 1118 std::list<std::pair<Operation *, RegionReadState>>; 1119 using LazyLoadableOpsMap = 1120 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; 1121 1122 public: 1123 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, 1124 llvm::MemoryBufferRef buffer, 1125 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 1126 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), 1127 attrTypeReader(stringReader, resourceReader, fileLoc), 1128 // Use the builtin unrealized conversion cast operation to represent 1129 // forward references to values that aren't yet defined. 1130 forwardRefOpState(UnknownLoc::get(config.getContext()), 1131 "builtin.unrealized_conversion_cast", ValueRange(), 1132 NoneType::get(config.getContext())), 1133 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} 1134 1135 /// Read the bytecode defined within `buffer` into the given block. 1136 LogicalResult read(Block *block, 1137 llvm::function_ref<bool(Operation *)> lazyOps); 1138 1139 /// Return the number of ops that haven't been materialized yet. 1140 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } 1141 1142 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } 1143 1144 /// Materialize the provided operation, invoke the lazyOpsCallback on every 1145 /// newly found lazy operation. 1146 LogicalResult 1147 materialize(Operation *op, 1148 llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1149 this->lazyOpsCallback = lazyOpsCallback; 1150 auto resetlazyOpsCallback = 1151 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1152 auto it = lazyLoadableOpsMap.find(op); 1153 assert(it != lazyLoadableOpsMap.end() && 1154 "materialize called on non-materializable op"); 1155 return materialize(it); 1156 } 1157 1158 /// Materialize all operations. 1159 LogicalResult materializeAll() { 1160 while (!lazyLoadableOpsMap.empty()) { 1161 if (failed(materialize(lazyLoadableOpsMap.begin()))) 1162 return failure(); 1163 } 1164 return success(); 1165 } 1166 1167 /// Finalize the lazy-loading by calling back with every op that hasn't been 1168 /// materialized to let the client decide if the op should be deleted or 1169 /// materialized. The op is materialized if the callback returns true, deleted 1170 /// otherwise. 1171 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { 1172 while (!lazyLoadableOps.empty()) { 1173 Operation *op = lazyLoadableOps.begin()->first; 1174 if (shouldMaterialize(op)) { 1175 if (failed(materialize(lazyLoadableOpsMap.find(op)))) 1176 return failure(); 1177 continue; 1178 } 1179 op->dropAllReferences(); 1180 op->erase(); 1181 lazyLoadableOps.pop_front(); 1182 lazyLoadableOpsMap.erase(op); 1183 } 1184 return success(); 1185 } 1186 1187 private: 1188 LogicalResult materialize(LazyLoadableOpsMap::iterator it) { 1189 assert(it != lazyLoadableOpsMap.end() && 1190 "materialize called on non-materializable op"); 1191 valueScopes.emplace_back(); 1192 std::vector<RegionReadState> regionStack; 1193 regionStack.push_back(std::move(it->getSecond()->second)); 1194 lazyLoadableOps.erase(it->getSecond()); 1195 lazyLoadableOpsMap.erase(it); 1196 auto result = parseRegions(regionStack, regionStack.back()); 1197 assert(regionStack.empty()); 1198 return result; 1199 } 1200 1201 /// Return the context for this config. 1202 MLIRContext *getContext() const { return config.getContext(); } 1203 1204 /// Parse the bytecode version. 1205 LogicalResult parseVersion(EncodingReader &reader); 1206 1207 //===--------------------------------------------------------------------===// 1208 // Dialect Section 1209 1210 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); 1211 1212 /// Parse an operation name reference using the given reader. 1213 FailureOr<OperationName> parseOpName(EncodingReader &reader); 1214 1215 //===--------------------------------------------------------------------===// 1216 // Attribute/Type Section 1217 1218 /// Parse an attribute or type using the given reader. 1219 template <typename T> 1220 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 1221 return attrTypeReader.parseAttribute(reader, result); 1222 } 1223 LogicalResult parseType(EncodingReader &reader, Type &result) { 1224 return attrTypeReader.parseType(reader, result); 1225 } 1226 1227 //===--------------------------------------------------------------------===// 1228 // Resource Section 1229 1230 LogicalResult 1231 parseResourceSection(EncodingReader &reader, 1232 std::optional<ArrayRef<uint8_t>> resourceData, 1233 std::optional<ArrayRef<uint8_t>> resourceOffsetData); 1234 1235 //===--------------------------------------------------------------------===// 1236 // IR Section 1237 1238 /// This struct represents the current read state of a range of regions. This 1239 /// struct is used to enable iterative parsing of regions. 1240 struct RegionReadState { 1241 RegionReadState(Operation *op, EncodingReader *reader, 1242 bool isIsolatedFromAbove) 1243 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} 1244 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, 1245 bool isIsolatedFromAbove) 1246 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), 1247 isIsolatedFromAbove(isIsolatedFromAbove) {} 1248 1249 /// The current regions being read. 1250 MutableArrayRef<Region>::iterator curRegion, endRegion; 1251 /// This is the reader to use for this region, this pointer is pointing to 1252 /// the parent region reader unless the current region is IsolatedFromAbove, 1253 /// in which case the pointer is pointing to the `owningReader` which is a 1254 /// section dedicated to the current region. 1255 EncodingReader *reader; 1256 std::unique_ptr<EncodingReader> owningReader; 1257 1258 /// The number of values defined immediately within this region. 1259 unsigned numValues = 0; 1260 1261 /// The current blocks of the region being read. 1262 SmallVector<Block *> curBlocks; 1263 Region::iterator curBlock = {}; 1264 1265 /// The number of operations remaining to be read from the current block 1266 /// being read. 1267 uint64_t numOpsRemaining = 0; 1268 1269 /// A flag indicating if the regions being read are isolated from above. 1270 bool isIsolatedFromAbove = false; 1271 }; 1272 1273 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); 1274 LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, 1275 RegionReadState &readState); 1276 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, 1277 RegionReadState &readState, 1278 bool &isIsolatedFromAbove); 1279 1280 LogicalResult parseRegion(RegionReadState &readState); 1281 LogicalResult parseBlockHeader(EncodingReader &reader, 1282 RegionReadState &readState); 1283 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); 1284 1285 //===--------------------------------------------------------------------===// 1286 // Value Processing 1287 1288 /// Parse an operand reference using the given reader. Returns nullptr in the 1289 /// case of failure. 1290 Value parseOperand(EncodingReader &reader); 1291 1292 /// Sequentially define the given value range. 1293 LogicalResult defineValues(EncodingReader &reader, ValueRange values); 1294 1295 /// Create a value to use for a forward reference. 1296 Value createForwardRef(); 1297 1298 //===--------------------------------------------------------------------===// 1299 // Use-list order helpers 1300 1301 /// This struct is a simple storage that contains information required to 1302 /// reorder the use-list of a value with respect to the pre-order traversal 1303 /// ordering. 1304 struct UseListOrderStorage { 1305 UseListOrderStorage(bool isIndexPairEncoding, 1306 SmallVector<unsigned, 4> &&indices) 1307 : indices(std::move(indices)), 1308 isIndexPairEncoding(isIndexPairEncoding){}; 1309 /// The vector containing the information required to reorder the 1310 /// use-list of a value. 1311 SmallVector<unsigned, 4> indices; 1312 1313 /// Whether indices represent a pair of type `(src, dst)` or it is a direct 1314 /// indexing, such as `dst = order[src]`. 1315 bool isIndexPairEncoding; 1316 }; 1317 1318 /// Parse use-list order from bytecode for a range of values if available. The 1319 /// range is expected to be either a block argument or an op result range. On 1320 /// success, return a map of the position in the range and the use-list order 1321 /// encoding. The function assumes to know the size of the range it is 1322 /// processing. 1323 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; 1324 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, 1325 uint64_t rangeSize); 1326 1327 /// Shuffle the use-chain according to the order parsed. 1328 LogicalResult sortUseListOrder(Value value); 1329 1330 /// Recursively visit all the values defined within topLevelOp and sort the 1331 /// use-list orders according to the indices parsed. 1332 LogicalResult processUseLists(Operation *topLevelOp); 1333 1334 //===--------------------------------------------------------------------===// 1335 // Fields 1336 1337 /// This class represents a single value scope, in which a value scope is 1338 /// delimited by isolated from above regions. 1339 struct ValueScope { 1340 /// Push a new region state onto this scope, reserving enough values for 1341 /// those defined within the current region of the provided state. 1342 void push(RegionReadState &readState) { 1343 nextValueIDs.push_back(values.size()); 1344 values.resize(values.size() + readState.numValues); 1345 } 1346 1347 /// Pop the values defined for the current region within the provided region 1348 /// state. 1349 void pop(RegionReadState &readState) { 1350 values.resize(values.size() - readState.numValues); 1351 nextValueIDs.pop_back(); 1352 } 1353 1354 /// The set of values defined in this scope. 1355 std::vector<Value> values; 1356 1357 /// The ID for the next defined value for each region current being 1358 /// processed in this scope. 1359 SmallVector<unsigned, 4> nextValueIDs; 1360 }; 1361 1362 /// The configuration of the parser. 1363 const ParserConfig &config; 1364 1365 /// A location to use when emitting errors. 1366 Location fileLoc; 1367 1368 /// Flag that indicates if lazyloading is enabled. 1369 bool lazyLoading; 1370 1371 /// Keep track of operations that have been lazy loaded (their regions haven't 1372 /// been materialized), along with the `RegionReadState` that allows to 1373 /// lazy-load the regions nested under the operation. 1374 LazyLoadableOpsInfo lazyLoadableOps; 1375 LazyLoadableOpsMap lazyLoadableOpsMap; 1376 llvm::function_ref<bool(Operation *)> lazyOpsCallback; 1377 1378 /// The reader used to process attribute and types within the bytecode. 1379 AttrTypeReader attrTypeReader; 1380 1381 /// The version of the bytecode being read. 1382 uint64_t version = 0; 1383 1384 /// The producer of the bytecode being read. 1385 StringRef producer; 1386 1387 /// The table of IR units referenced within the bytecode file. 1388 SmallVector<BytecodeDialect> dialects; 1389 SmallVector<BytecodeOperationName> opNames; 1390 1391 /// The reader used to process resources within the bytecode. 1392 ResourceSectionReader resourceReader; 1393 1394 /// Worklist of values with custom use-list orders to process before the end 1395 /// of the parsing. 1396 DenseMap<void *, UseListOrderStorage> valueToUseListMap; 1397 1398 /// The table of strings referenced within the bytecode file. 1399 StringSectionReader stringReader; 1400 1401 /// The current set of available IR value scopes. 1402 std::vector<ValueScope> valueScopes; 1403 1404 /// The global pre-order operation ordering. 1405 DenseMap<Operation *, unsigned> operationIDs; 1406 1407 /// A block containing the set of operations defined to create forward 1408 /// references. 1409 Block forwardRefOps; 1410 1411 /// A block containing previously created, and no longer used, forward 1412 /// reference operations. 1413 Block openForwardRefOps; 1414 1415 /// An operation state used when instantiating forward references. 1416 OperationState forwardRefOpState; 1417 1418 /// Reference to the input buffer. 1419 llvm::MemoryBufferRef buffer; 1420 1421 /// The optional owning source manager, which when present may be used to 1422 /// extend the lifetime of the input buffer. 1423 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 1424 }; 1425 1426 LogicalResult BytecodeReader::Impl::read( 1427 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1428 EncodingReader reader(buffer.getBuffer(), fileLoc); 1429 this->lazyOpsCallback = lazyOpsCallback; 1430 auto resetlazyOpsCallback = 1431 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1432 1433 // Skip over the bytecode header, this should have already been checked. 1434 if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) 1435 return failure(); 1436 // Parse the bytecode version and producer. 1437 if (failed(parseVersion(reader)) || 1438 failed(reader.parseNullTerminatedString(producer))) 1439 return failure(); 1440 1441 // Add a diagnostic handler that attaches a note that includes the original 1442 // producer of the bytecode. 1443 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { 1444 diag.attachNote() << "in bytecode version " << version 1445 << " produced by: " << producer; 1446 return failure(); 1447 }); 1448 1449 // Parse the raw data for each of the top-level sections of the bytecode. 1450 std::optional<ArrayRef<uint8_t>> 1451 sectionDatas[bytecode::Section::kNumSections]; 1452 while (!reader.empty()) { 1453 // Read the next section from the bytecode. 1454 bytecode::Section::ID sectionID; 1455 ArrayRef<uint8_t> sectionData; 1456 if (failed(reader.parseSection(sectionID, sectionData))) 1457 return failure(); 1458 1459 // Check for duplicate sections, we only expect one instance of each. 1460 if (sectionDatas[sectionID]) { 1461 return reader.emitError("duplicate top-level section: ", 1462 ::toString(sectionID)); 1463 } 1464 sectionDatas[sectionID] = sectionData; 1465 } 1466 // Check that all of the required sections were found. 1467 for (int i = 0; i < bytecode::Section::kNumSections; ++i) { 1468 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); 1469 if (!sectionDatas[i] && !isSectionOptional(sectionID)) { 1470 return reader.emitError("missing data for top-level section: ", 1471 ::toString(sectionID)); 1472 } 1473 } 1474 1475 // Process the string section first. 1476 if (failed(stringReader.initialize( 1477 fileLoc, *sectionDatas[bytecode::Section::kString]))) 1478 return failure(); 1479 1480 // Process the dialect section. 1481 if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) 1482 return failure(); 1483 1484 // Process the resource section if present. 1485 if (failed(parseResourceSection( 1486 reader, sectionDatas[bytecode::Section::kResource], 1487 sectionDatas[bytecode::Section::kResourceOffset]))) 1488 return failure(); 1489 1490 // Process the attribute and type section. 1491 if (failed(attrTypeReader.initialize( 1492 dialects, *sectionDatas[bytecode::Section::kAttrType], 1493 *sectionDatas[bytecode::Section::kAttrTypeOffset]))) 1494 return failure(); 1495 1496 // Finally, process the IR section. 1497 return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); 1498 } 1499 1500 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { 1501 if (failed(reader.parseVarInt(version))) 1502 return failure(); 1503 1504 // Validate the bytecode version. 1505 uint64_t currentVersion = bytecode::kVersion; 1506 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; 1507 if (version < minSupportedVersion) { 1508 return reader.emitError("bytecode version ", version, 1509 " is older than the current version of ", 1510 currentVersion, ", and upgrade is not supported"); 1511 } 1512 if (version > currentVersion) { 1513 return reader.emitError("bytecode version ", version, 1514 " is newer than the current version ", 1515 currentVersion); 1516 } 1517 // Override any request to lazy-load if the bytecode version is too old. 1518 if (version < 2) 1519 lazyLoading = false; 1520 return success(); 1521 } 1522 1523 //===----------------------------------------------------------------------===// 1524 // Dialect Section 1525 1526 LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { 1527 if (dialect) 1528 return success(); 1529 Dialect *loadedDialect = ctx->getOrLoadDialect(name); 1530 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { 1531 return reader.emitError("dialect '") 1532 << name 1533 << "' is unknown. If this is intended, please call " 1534 "allowUnregisteredDialects() on the MLIRContext, or use " 1535 "-allow-unregistered-dialect with the MLIR tool used."; 1536 } 1537 dialect = loadedDialect; 1538 1539 // If the dialect was actually loaded, check to see if it has a bytecode 1540 // interface. 1541 if (loadedDialect) 1542 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); 1543 if (!versionBuffer.empty()) { 1544 if (!interface) 1545 return reader.emitError("dialect '") 1546 << name 1547 << "' does not implement the bytecode interface, " 1548 "but found a version entry"; 1549 EncodingReader encReader(versionBuffer, reader.getLoc()); 1550 DialectReader versionReader = reader.withEncodingReader(encReader); 1551 loadedVersion = interface->readVersion(versionReader); 1552 if (!loadedVersion) 1553 return failure(); 1554 } 1555 return success(); 1556 } 1557 1558 LogicalResult 1559 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { 1560 EncodingReader sectionReader(sectionData, fileLoc); 1561 1562 // Parse the number of dialects in the section. 1563 uint64_t numDialects; 1564 if (failed(sectionReader.parseVarInt(numDialects))) 1565 return failure(); 1566 dialects.resize(numDialects); 1567 1568 // Parse each of the dialects. 1569 for (uint64_t i = 0; i < numDialects; ++i) { 1570 /// Before version 1, there wasn't any versioning available for dialects, 1571 /// and the entryIdx represent the string itself. 1572 if (version == 0) { 1573 if (failed(stringReader.parseString(sectionReader, dialects[i].name))) 1574 return failure(); 1575 continue; 1576 } 1577 // Parse ID representing dialect and version. 1578 uint64_t dialectNameIdx; 1579 bool versionAvailable; 1580 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, 1581 versionAvailable))) 1582 return failure(); 1583 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, 1584 dialects[i].name))) 1585 return failure(); 1586 if (versionAvailable) { 1587 bytecode::Section::ID sectionID; 1588 if (failed( 1589 sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) 1590 return failure(); 1591 if (sectionID != bytecode::Section::kDialectVersions) { 1592 emitError(fileLoc, "expected dialect version section"); 1593 return failure(); 1594 } 1595 } 1596 } 1597 1598 // Parse the operation names, which are grouped by dialect. 1599 auto parseOpName = [&](BytecodeDialect *dialect) { 1600 StringRef opName; 1601 if (failed(stringReader.parseString(sectionReader, opName))) 1602 return failure(); 1603 opNames.emplace_back(dialect, opName); 1604 return success(); 1605 }; 1606 // Avoid re-allocation in bytecode version > 3 where the number of ops are 1607 // known. 1608 if (version > 3) { 1609 uint64_t numOps; 1610 if (failed(sectionReader.parseVarInt(numOps))) 1611 return failure(); 1612 opNames.reserve(numOps); 1613 } 1614 while (!sectionReader.empty()) 1615 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) 1616 return failure(); 1617 return success(); 1618 } 1619 1620 FailureOr<OperationName> 1621 BytecodeReader::Impl::parseOpName(EncodingReader &reader) { 1622 BytecodeOperationName *opName = nullptr; 1623 if (failed(parseEntry(reader, opNames, opName, "operation name"))) 1624 return failure(); 1625 1626 // Check to see if this operation name has already been resolved. If we 1627 // haven't, load the dialect and build the operation name. 1628 if (!opName->opName) { 1629 // Load the dialect and its version. 1630 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1631 reader); 1632 if (failed(opName->dialect->load(dialectReader, getContext()))) 1633 return failure(); 1634 // If the opName is empty, this is because we use to accept names such as 1635 // `foo` without any `.` separator. We shouldn't tolerate this in textual 1636 // format anymore but for now we'll be backward compatible. This can only 1637 // happen with unregistered dialects. 1638 if (opName->name.empty()) { 1639 if (opName->dialect->getLoadedDialect()) 1640 return emitError(fileLoc) << "has an empty opname for dialect '" 1641 << opName->dialect->name << "'\n"; 1642 1643 opName->opName.emplace(opName->dialect->name, getContext()); 1644 } else { 1645 opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), 1646 getContext()); 1647 } 1648 } 1649 return *opName->opName; 1650 } 1651 1652 //===----------------------------------------------------------------------===// 1653 // Resource Section 1654 1655 LogicalResult BytecodeReader::Impl::parseResourceSection( 1656 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, 1657 std::optional<ArrayRef<uint8_t>> resourceOffsetData) { 1658 // Ensure both sections are either present or not. 1659 if (resourceData.has_value() != resourceOffsetData.has_value()) { 1660 if (resourceOffsetData) 1661 return emitError(fileLoc, "unexpected resource offset section when " 1662 "resource section is not present"); 1663 return emitError( 1664 fileLoc, 1665 "expected resource offset section when resource section is present"); 1666 } 1667 1668 // If the resource sections are absent, there is nothing to do. 1669 if (!resourceData) 1670 return success(); 1671 1672 // Initialize the resource reader with the resource sections. 1673 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1674 reader); 1675 return resourceReader.initialize(fileLoc, config, dialects, stringReader, 1676 *resourceData, *resourceOffsetData, 1677 dialectReader, bufferOwnerRef); 1678 } 1679 1680 //===----------------------------------------------------------------------===// 1681 // UseListOrder Helpers 1682 1683 FailureOr<BytecodeReader::Impl::UseListMapT> 1684 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, 1685 uint64_t numResults) { 1686 BytecodeReader::Impl::UseListMapT map; 1687 uint64_t numValuesToRead = 1; 1688 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) 1689 return failure(); 1690 1691 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { 1692 uint64_t resultIdx = 0; 1693 if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) 1694 return failure(); 1695 1696 uint64_t numValues; 1697 bool indexPairEncoding; 1698 if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) 1699 return failure(); 1700 1701 SmallVector<unsigned, 4> useListOrders; 1702 for (size_t idx = 0; idx < numValues; idx++) { 1703 uint64_t index; 1704 if (failed(reader.parseVarInt(index))) 1705 return failure(); 1706 useListOrders.push_back(index); 1707 } 1708 1709 // Store in a map the result index 1710 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, 1711 std::move(useListOrders))); 1712 } 1713 1714 return map; 1715 } 1716 1717 /// Sorts each use according to the order specified in the use-list parsed. If 1718 /// the custom use-list is not found, this means that the order needs to be 1719 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie 1720 /// on the same operation, the order will follow the reverse operand number 1721 /// ordering. 1722 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { 1723 // Early return for trivial use-lists. 1724 if (value.use_empty() || value.hasOneUse()) 1725 return success(); 1726 1727 bool hasIncomingOrder = 1728 valueToUseListMap.contains(value.getAsOpaquePointer()); 1729 1730 // Compute the current order of the use-list with respect to the global 1731 // ordering. Detect if the order is already sorted while doing so. 1732 bool alreadySorted = true; 1733 auto &firstUse = *value.use_begin(); 1734 uint64_t prevID = 1735 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); 1736 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; 1737 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1738 uint64_t currentID = bytecode::getUseID( 1739 item.value(), operationIDs.at(item.value().getOwner())); 1740 alreadySorted &= prevID > currentID; 1741 currentOrder.push_back({item.index(), currentID}); 1742 prevID = currentID; 1743 } 1744 1745 // If the order is already sorted, and there wasn't a custom order to apply 1746 // from the bytecode file, we are done. 1747 if (alreadySorted && !hasIncomingOrder) 1748 return success(); 1749 1750 // If not already sorted, sort the indices of the current order by descending 1751 // useIDs. 1752 if (!alreadySorted) 1753 std::sort( 1754 currentOrder.begin(), currentOrder.end(), 1755 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1756 1757 if (!hasIncomingOrder) { 1758 // If the bytecode file did not contain any custom use-list order, it means 1759 // that the order was descending useID. Hence, shuffle by the first index 1760 // of the `currentOrder` pair. 1761 SmallVector<unsigned> shuffle = SmallVector<unsigned>( 1762 llvm::map_range(currentOrder, [&](auto item) { return item.first; })); 1763 value.shuffleUseList(shuffle); 1764 return success(); 1765 } 1766 1767 // Pull the custom order info from the map. 1768 UseListOrderStorage customOrder = 1769 valueToUseListMap.at(value.getAsOpaquePointer()); 1770 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); 1771 uint64_t numUses = 1772 std::distance(value.getUses().begin(), value.getUses().end()); 1773 1774 // If the encoding was a pair of indices `(src, dst)` for every permutation, 1775 // reconstruct the shuffle vector for every use. Initialize the shuffle vector 1776 // as identity, and then apply the mapping encoded in the indices. 1777 if (customOrder.isIndexPairEncoding) { 1778 // Return failure if the number of indices was not representing pairs. 1779 if (shuffle.size() & 1) 1780 return failure(); 1781 1782 SmallVector<unsigned, 4> newShuffle(numUses); 1783 size_t idx = 0; 1784 std::iota(newShuffle.begin(), newShuffle.end(), idx); 1785 for (idx = 0; idx < shuffle.size(); idx += 2) 1786 newShuffle[shuffle[idx]] = shuffle[idx + 1]; 1787 1788 shuffle = std::move(newShuffle); 1789 } 1790 1791 // Make sure that the indices represent a valid mapping. That is, the sum of 1792 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no 1793 // duplicates are allowed in the list. 1794 DenseSet<unsigned> set; 1795 uint64_t accumulator = 0; 1796 for (const auto &elem : shuffle) { 1797 if (set.contains(elem)) 1798 return failure(); 1799 accumulator += elem; 1800 set.insert(elem); 1801 } 1802 if (numUses != shuffle.size() || 1803 accumulator != (((numUses - 1) * numUses) >> 1)) 1804 return failure(); 1805 1806 // Apply the current ordering map onto the shuffle vector to get the final 1807 // use-list sorting indices before shuffling. 1808 shuffle = SmallVector<unsigned, 4>(llvm::map_range( 1809 currentOrder, [&](auto item) { return shuffle[item.first]; })); 1810 value.shuffleUseList(shuffle); 1811 return success(); 1812 } 1813 1814 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { 1815 // Precompute operation IDs according to the pre-order walk of the IR. We 1816 // can't do this while parsing since parseRegions ordering is not strictly 1817 // equal to the pre-order walk. 1818 unsigned operationID = 0; 1819 topLevelOp->walk<mlir::WalkOrder::PreOrder>( 1820 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); 1821 1822 auto blockWalk = topLevelOp->walk([this](Block *block) { 1823 for (auto arg : block->getArguments()) 1824 if (failed(sortUseListOrder(arg))) 1825 return WalkResult::interrupt(); 1826 return WalkResult::advance(); 1827 }); 1828 1829 auto resultWalk = topLevelOp->walk([this](Operation *op) { 1830 for (auto result : op->getResults()) 1831 if (failed(sortUseListOrder(result))) 1832 return WalkResult::interrupt(); 1833 return WalkResult::advance(); 1834 }); 1835 1836 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); 1837 } 1838 1839 //===----------------------------------------------------------------------===// 1840 // IR Section 1841 1842 LogicalResult 1843 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, 1844 Block *block) { 1845 EncodingReader reader(sectionData, fileLoc); 1846 1847 // A stack of operation regions currently being read from the bytecode. 1848 std::vector<RegionReadState> regionStack; 1849 1850 // Parse the top-level block using a temporary module operation. 1851 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); 1852 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); 1853 regionStack.back().curBlocks.push_back(moduleOp->getBody()); 1854 regionStack.back().curBlock = regionStack.back().curRegion->begin(); 1855 if (failed(parseBlockHeader(reader, regionStack.back()))) 1856 return failure(); 1857 valueScopes.emplace_back(); 1858 valueScopes.back().push(regionStack.back()); 1859 1860 // Iteratively parse regions until everything has been resolved. 1861 while (!regionStack.empty()) 1862 if (failed(parseRegions(regionStack, regionStack.back()))) 1863 return failure(); 1864 if (!forwardRefOps.empty()) { 1865 return reader.emitError( 1866 "not all forward unresolved forward operand references"); 1867 } 1868 1869 // Sort use-lists according to what specified in bytecode. 1870 if (failed(processUseLists(*moduleOp))) 1871 return reader.emitError( 1872 "parsed use-list orders were invalid and could not be applied"); 1873 1874 // Resolve dialect version. 1875 for (const BytecodeDialect &byteCodeDialect : dialects) { 1876 // Parsing is complete, give an opportunity to each dialect to visit the 1877 // IR and perform upgrades. 1878 if (!byteCodeDialect.loadedVersion) 1879 continue; 1880 if (byteCodeDialect.interface && 1881 failed(byteCodeDialect.interface->upgradeFromVersion( 1882 *moduleOp, *byteCodeDialect.loadedVersion))) 1883 return failure(); 1884 } 1885 1886 // Verify that the parsed operations are valid. 1887 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) 1888 return failure(); 1889 1890 // Splice the parsed operations over to the provided top-level block. 1891 auto &parsedOps = moduleOp->getBody()->getOperations(); 1892 auto &destOps = block->getOperations(); 1893 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); 1894 return success(); 1895 } 1896 1897 LogicalResult 1898 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, 1899 RegionReadState &readState) { 1900 // Process regions, blocks, and operations until the end or if a nested 1901 // region is encountered. In this case we push a new state in regionStack and 1902 // return, the processing of the current region will resume afterward. 1903 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { 1904 // If the current block hasn't been setup yet, parse the header for this 1905 // region. The current block is already setup when this function was 1906 // interrupted to recurse down in a nested region and we resume the current 1907 // block after processing the nested region. 1908 if (readState.curBlock == Region::iterator()) { 1909 if (failed(parseRegion(readState))) 1910 return failure(); 1911 1912 // If the region is empty, there is nothing to more to do. 1913 if (readState.curRegion->empty()) 1914 continue; 1915 } 1916 1917 // Parse the blocks within the region. 1918 EncodingReader &reader = *readState.reader; 1919 do { 1920 while (readState.numOpsRemaining--) { 1921 // Read in the next operation. We don't read its regions directly, we 1922 // handle those afterwards as necessary. 1923 bool isIsolatedFromAbove = false; 1924 FailureOr<Operation *> op = 1925 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); 1926 if (failed(op)) 1927 return failure(); 1928 1929 // If the op has regions, add it to the stack for processing and return: 1930 // we stop the processing of the current region and resume it after the 1931 // inner one is completed. Unless LazyLoading is activated in which case 1932 // nested region parsing is delayed. 1933 if ((*op)->getNumRegions()) { 1934 RegionReadState childState(*op, &reader, isIsolatedFromAbove); 1935 1936 // Isolated regions are encoded as a section in version 2 and above. 1937 if (version >= 2 && isIsolatedFromAbove) { 1938 bytecode::Section::ID sectionID; 1939 ArrayRef<uint8_t> sectionData; 1940 if (failed(reader.parseSection(sectionID, sectionData))) 1941 return failure(); 1942 if (sectionID != bytecode::Section::kIR) 1943 return emitError(fileLoc, "expected IR section for region"); 1944 childState.owningReader = 1945 std::make_unique<EncodingReader>(sectionData, fileLoc); 1946 childState.reader = childState.owningReader.get(); 1947 } 1948 1949 if (lazyLoading) { 1950 // If the user has a callback set, they have the opportunity 1951 // to control lazyloading as we go. 1952 if (!lazyOpsCallback || !lazyOpsCallback(*op)) { 1953 lazyLoadableOps.push_back( 1954 std::make_pair(*op, std::move(childState))); 1955 lazyLoadableOpsMap.try_emplace(*op, 1956 std::prev(lazyLoadableOps.end())); 1957 continue; 1958 } 1959 } 1960 regionStack.push_back(std::move(childState)); 1961 1962 // If the op is isolated from above, push a new value scope. 1963 if (isIsolatedFromAbove) 1964 valueScopes.emplace_back(); 1965 return success(); 1966 } 1967 } 1968 1969 // Move to the next block of the region. 1970 if (++readState.curBlock == readState.curRegion->end()) 1971 break; 1972 if (failed(parseBlockHeader(reader, readState))) 1973 return failure(); 1974 } while (true); 1975 1976 // Reset the current block and any values reserved for this region. 1977 readState.curBlock = {}; 1978 valueScopes.back().pop(readState); 1979 } 1980 1981 // When the regions have been fully parsed, pop them off of the read stack. If 1982 // the regions were isolated from above, we also pop the last value scope. 1983 if (readState.isIsolatedFromAbove) { 1984 assert(!valueScopes.empty() && "Expect a valueScope after reading region"); 1985 valueScopes.pop_back(); 1986 } 1987 assert(!regionStack.empty() && "Expect a regionStack after reading region"); 1988 regionStack.pop_back(); 1989 return success(); 1990 } 1991 1992 FailureOr<Operation *> 1993 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, 1994 RegionReadState &readState, 1995 bool &isIsolatedFromAbove) { 1996 // Parse the name of the operation. 1997 FailureOr<OperationName> opName = parseOpName(reader); 1998 if (failed(opName)) 1999 return failure(); 2000 2001 // Parse the operation mask, which indicates which components of the operation 2002 // are present. 2003 uint8_t opMask; 2004 if (failed(reader.parseByte(opMask))) 2005 return failure(); 2006 2007 /// Parse the location. 2008 LocationAttr opLoc; 2009 if (failed(parseAttribute(reader, opLoc))) 2010 return failure(); 2011 2012 // With the location and name resolved, we can start building the operation 2013 // state. 2014 OperationState opState(opLoc, *opName); 2015 2016 // Parse the attributes of the operation. 2017 if (opMask & bytecode::OpEncodingMask::kHasAttrs) { 2018 DictionaryAttr dictAttr; 2019 if (failed(parseAttribute(reader, dictAttr))) 2020 return failure(); 2021 opState.attributes = dictAttr; 2022 } 2023 2024 /// Parse the results of the operation. 2025 if (opMask & bytecode::OpEncodingMask::kHasResults) { 2026 uint64_t numResults; 2027 if (failed(reader.parseVarInt(numResults))) 2028 return failure(); 2029 opState.types.resize(numResults); 2030 for (int i = 0, e = numResults; i < e; ++i) 2031 if (failed(parseType(reader, opState.types[i]))) 2032 return failure(); 2033 } 2034 2035 /// Parse the operands of the operation. 2036 if (opMask & bytecode::OpEncodingMask::kHasOperands) { 2037 uint64_t numOperands; 2038 if (failed(reader.parseVarInt(numOperands))) 2039 return failure(); 2040 opState.operands.resize(numOperands); 2041 for (int i = 0, e = numOperands; i < e; ++i) 2042 if (!(opState.operands[i] = parseOperand(reader))) 2043 return failure(); 2044 } 2045 2046 /// Parse the successors of the operation. 2047 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { 2048 uint64_t numSuccs; 2049 if (failed(reader.parseVarInt(numSuccs))) 2050 return failure(); 2051 opState.successors.resize(numSuccs); 2052 for (int i = 0, e = numSuccs; i < e; ++i) { 2053 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], 2054 "successor"))) 2055 return failure(); 2056 } 2057 } 2058 2059 /// Parse the use-list orders for the results of the operation. Use-list 2060 /// orders are available since version 3 of the bytecode. 2061 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; 2062 if (version > 2 && (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { 2063 size_t numResults = opState.types.size(); 2064 auto parseResult = parseUseListOrderForRange(reader, numResults); 2065 if (failed(parseResult)) 2066 return failure(); 2067 resultIdxToUseListMap = std::move(*parseResult); 2068 } 2069 2070 /// Parse the regions of the operation. 2071 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { 2072 uint64_t numRegions; 2073 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) 2074 return failure(); 2075 2076 opState.regions.reserve(numRegions); 2077 for (int i = 0, e = numRegions; i < e; ++i) 2078 opState.regions.push_back(std::make_unique<Region>()); 2079 } 2080 2081 // Create the operation at the back of the current block. 2082 Operation *op = Operation::create(opState); 2083 readState.curBlock->push_back(op); 2084 2085 // If the operation had results, update the value references. 2086 if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) 2087 return failure(); 2088 2089 /// Store a map for every value that received a custom use-list order from the 2090 /// bytecode file. 2091 if (resultIdxToUseListMap.has_value()) { 2092 for (size_t idx = 0; idx < op->getNumResults(); idx++) { 2093 if (resultIdxToUseListMap->contains(idx)) { 2094 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), 2095 resultIdxToUseListMap->at(idx)); 2096 } 2097 } 2098 } 2099 return op; 2100 } 2101 2102 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { 2103 EncodingReader &reader = *readState.reader; 2104 2105 // Parse the number of blocks in the region. 2106 uint64_t numBlocks; 2107 if (failed(reader.parseVarInt(numBlocks))) 2108 return failure(); 2109 2110 // If the region is empty, there is nothing else to do. 2111 if (numBlocks == 0) 2112 return success(); 2113 2114 // Parse the number of values defined in this region. 2115 uint64_t numValues; 2116 if (failed(reader.parseVarInt(numValues))) 2117 return failure(); 2118 readState.numValues = numValues; 2119 2120 // Create the blocks within this region. We do this before processing so that 2121 // we can rely on the blocks existing when creating operations. 2122 readState.curBlocks.clear(); 2123 readState.curBlocks.reserve(numBlocks); 2124 for (uint64_t i = 0; i < numBlocks; ++i) { 2125 readState.curBlocks.push_back(new Block()); 2126 readState.curRegion->push_back(readState.curBlocks.back()); 2127 } 2128 2129 // Prepare the current value scope for this region. 2130 valueScopes.back().push(readState); 2131 2132 // Parse the entry block of the region. 2133 readState.curBlock = readState.curRegion->begin(); 2134 return parseBlockHeader(reader, readState); 2135 } 2136 2137 LogicalResult 2138 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, 2139 RegionReadState &readState) { 2140 bool hasArgs; 2141 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) 2142 return failure(); 2143 2144 // Parse the arguments of the block. 2145 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) 2146 return failure(); 2147 2148 // Uselist orders are available since version 3 of the bytecode. 2149 if (version < 3) 2150 return success(); 2151 2152 uint8_t hasUseListOrders = 0; 2153 if (hasArgs && failed(reader.parseByte(hasUseListOrders))) 2154 return failure(); 2155 2156 if (!hasUseListOrders) 2157 return success(); 2158 2159 Block &blk = *readState.curBlock; 2160 auto argIdxToUseListMap = 2161 parseUseListOrderForRange(reader, blk.getNumArguments()); 2162 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) 2163 return failure(); 2164 2165 for (size_t idx = 0; idx < blk.getNumArguments(); idx++) 2166 if (argIdxToUseListMap->contains(idx)) 2167 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), 2168 argIdxToUseListMap->at(idx)); 2169 2170 // We don't parse the operations of the block here, that's done elsewhere. 2171 return success(); 2172 } 2173 2174 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, 2175 Block *block) { 2176 // Parse the value ID for the first argument, and the number of arguments. 2177 uint64_t numArgs; 2178 if (failed(reader.parseVarInt(numArgs))) 2179 return failure(); 2180 2181 SmallVector<Type> argTypes; 2182 SmallVector<Location> argLocs; 2183 argTypes.reserve(numArgs); 2184 argLocs.reserve(numArgs); 2185 2186 Location unknownLoc = UnknownLoc::get(config.getContext()); 2187 while (numArgs--) { 2188 Type argType; 2189 LocationAttr argLoc = unknownLoc; 2190 if (version > 3) { 2191 // Parse the type with hasLoc flag to determine if it has type. 2192 uint64_t typeIdx; 2193 bool hasLoc; 2194 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || 2195 !(argType = attrTypeReader.resolveType(typeIdx))) 2196 return failure(); 2197 if (hasLoc && failed(parseAttribute(reader, argLoc))) 2198 return failure(); 2199 } else { 2200 // All args has type and location. 2201 if (failed(parseType(reader, argType)) || 2202 failed(parseAttribute(reader, argLoc))) 2203 return failure(); 2204 } 2205 argTypes.push_back(argType); 2206 argLocs.push_back(argLoc); 2207 } 2208 block->addArguments(argTypes, argLocs); 2209 return defineValues(reader, block->getArguments()); 2210 } 2211 2212 //===----------------------------------------------------------------------===// 2213 // Value Processing 2214 2215 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { 2216 std::vector<Value> &values = valueScopes.back().values; 2217 Value *value = nullptr; 2218 if (failed(parseEntry(reader, values, value, "value"))) 2219 return Value(); 2220 2221 // Create a new forward reference if necessary. 2222 if (!*value) 2223 *value = createForwardRef(); 2224 return *value; 2225 } 2226 2227 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, 2228 ValueRange newValues) { 2229 ValueScope &valueScope = valueScopes.back(); 2230 std::vector<Value> &values = valueScope.values; 2231 2232 unsigned &valueID = valueScope.nextValueIDs.back(); 2233 unsigned valueIDEnd = valueID + newValues.size(); 2234 if (valueIDEnd > values.size()) { 2235 return reader.emitError( 2236 "value index range was outside of the expected range for " 2237 "the parent region, got [", 2238 valueID, ", ", valueIDEnd, "), but the maximum index was ", 2239 values.size() - 1); 2240 } 2241 2242 // Assign the values and update any forward references. 2243 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { 2244 Value newValue = newValues[i]; 2245 2246 // Check to see if a definition for this value already exists. 2247 if (Value oldValue = std::exchange(values[valueID], newValue)) { 2248 Operation *forwardRefOp = oldValue.getDefiningOp(); 2249 2250 // Assert that this is a forward reference operation. Given how we compute 2251 // definition ids (incrementally as we parse), it shouldn't be possible 2252 // for the value to be defined any other way. 2253 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && 2254 "value index was already defined?"); 2255 2256 oldValue.replaceAllUsesWith(newValue); 2257 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); 2258 } 2259 } 2260 return success(); 2261 } 2262 2263 Value BytecodeReader::Impl::createForwardRef() { 2264 // Check for an avaliable existing operation to use. Otherwise, create a new 2265 // fake operation to use for the reference. 2266 if (!openForwardRefOps.empty()) { 2267 Operation *op = &openForwardRefOps.back(); 2268 op->moveBefore(&forwardRefOps, forwardRefOps.end()); 2269 } else { 2270 forwardRefOps.push_back(Operation::create(forwardRefOpState)); 2271 } 2272 return forwardRefOps.back().getResult(0); 2273 } 2274 2275 //===----------------------------------------------------------------------===// 2276 // Entry Points 2277 //===----------------------------------------------------------------------===// 2278 2279 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } 2280 2281 BytecodeReader::BytecodeReader( 2282 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, 2283 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2284 Location sourceFileLoc = 2285 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2286 /*line=*/0, /*column=*/0); 2287 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, 2288 bufferOwnerRef); 2289 } 2290 2291 LogicalResult BytecodeReader::readTopLevel( 2292 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2293 return impl->read(block, lazyOpsCallback); 2294 } 2295 2296 int64_t BytecodeReader::getNumOpsToMaterialize() const { 2297 return impl->getNumOpsToMaterialize(); 2298 } 2299 2300 bool BytecodeReader::isMaterializable(Operation *op) { 2301 return impl->isMaterializable(op); 2302 } 2303 2304 LogicalResult BytecodeReader::materialize( 2305 Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2306 return impl->materialize(op, lazyOpsCallback); 2307 } 2308 2309 LogicalResult 2310 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { 2311 return impl->finalize(shouldMaterialize); 2312 } 2313 2314 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { 2315 return buffer.getBuffer().startswith("ML\xefR"); 2316 } 2317 2318 /// Read the bytecode from the provided memory buffer reference. 2319 /// `bufferOwnerRef` if provided is the owning source manager for the buffer, 2320 /// and may be used to extend the lifetime of the buffer. 2321 static LogicalResult 2322 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, 2323 const ParserConfig &config, 2324 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2325 Location sourceFileLoc = 2326 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2327 /*line=*/0, /*column=*/0); 2328 if (!isBytecode(buffer)) { 2329 return emitError(sourceFileLoc, 2330 "input buffer is not an MLIR bytecode file"); 2331 } 2332 2333 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, 2334 buffer, bufferOwnerRef); 2335 return reader.read(block, /*lazyOpsCallback=*/nullptr); 2336 } 2337 2338 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, 2339 const ParserConfig &config) { 2340 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); 2341 } 2342 LogicalResult 2343 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 2344 Block *block, const ParserConfig &config) { 2345 return readBytecodeFileImpl( 2346 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, 2347 sourceMgr); 2348 } 2349