//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/SourceMgr.h"
#include <cstddef>
#include <list>
#include <memory>
#include <numeric>
#include <optional>

#define DEBUG_TYPE "mlir-bytecode-reader"

using namespace mlir;

/// Stringify the given section ID.
43 static std::string toString(bytecode::Section::ID sectionID) { 44 switch (sectionID) { 45 case bytecode::Section::kString: 46 return "String (0)"; 47 case bytecode::Section::kDialect: 48 return "Dialect (1)"; 49 case bytecode::Section::kAttrType: 50 return "AttrType (2)"; 51 case bytecode::Section::kAttrTypeOffset: 52 return "AttrTypeOffset (3)"; 53 case bytecode::Section::kIR: 54 return "IR (4)"; 55 case bytecode::Section::kResource: 56 return "Resource (5)"; 57 case bytecode::Section::kResourceOffset: 58 return "ResourceOffset (6)"; 59 case bytecode::Section::kDialectVersions: 60 return "DialectVersions (7)"; 61 case bytecode::Section::kProperties: 62 return "Properties (8)"; 63 default: 64 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); 65 } 66 } 67 68 /// Returns true if the given top-level section ID is optional. 69 static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { 70 switch (sectionID) { 71 case bytecode::Section::kString: 72 case bytecode::Section::kDialect: 73 case bytecode::Section::kAttrType: 74 case bytecode::Section::kAttrTypeOffset: 75 case bytecode::Section::kIR: 76 return false; 77 case bytecode::Section::kResource: 78 case bytecode::Section::kResourceOffset: 79 case bytecode::Section::kDialectVersions: 80 return true; 81 case bytecode::Section::kProperties: 82 return version < bytecode::kNativePropertiesEncoding; 83 default: 84 llvm_unreachable("unknown section ID"); 85 } 86 } 87 88 //===----------------------------------------------------------------------===// 89 // EncodingReader 90 //===----------------------------------------------------------------------===// 91 92 namespace { 93 class EncodingReader { 94 public: 95 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) 96 : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} 97 explicit EncodingReader(StringRef contents, Location fileLoc) 98 : EncodingReader({reinterpret_cast<const uint8_t 
*>(contents.data()), 99 contents.size()}, 100 fileLoc) {} 101 102 /// Returns true if the entire section has been read. 103 bool empty() const { return dataIt == dataEnd; } 104 105 /// Returns the remaining size of the bytecode. 106 size_t size() const { return dataEnd - dataIt; } 107 108 /// Align the current reader position to the specified alignment. 109 LogicalResult alignTo(unsigned alignment) { 110 if (!llvm::isPowerOf2_32(alignment)) 111 return emitError("expected alignment to be a power-of-two"); 112 113 // Shift the reader position to the next alignment boundary. 114 while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) { 115 uint8_t padding; 116 if (failed(parseByte(padding))) 117 return failure(); 118 if (padding != bytecode::kAlignmentByte) { 119 return emitError("expected alignment byte (0xCB), but got: '0x" + 120 llvm::utohexstr(padding) + "'"); 121 } 122 } 123 124 // Ensure the data iterator is now aligned. This case is unlikely because we 125 // *just* went through the effort to align the data iterator. 126 if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) { 127 return emitError("expected data iterator aligned to ", alignment, 128 ", but got pointer: '0x" + 129 llvm::utohexstr((uintptr_t)dataIt) + "'"); 130 } 131 132 return success(); 133 } 134 135 /// Emit an error using the given arguments. 136 template <typename... Args> 137 InFlightDiagnostic emitError(Args &&...args) const { 138 return ::emitError(fileLoc).append(std::forward<Args>(args)...); 139 } 140 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } 141 142 /// Parse a single byte from the stream. 143 template <typename T> 144 LogicalResult parseByte(T &value) { 145 if (empty()) 146 return emitError("attempting to parse a byte at the end of the bytecode"); 147 value = static_cast<T>(*dataIt++); 148 return success(); 149 } 150 /// Parse a range of bytes of 'length' into the given result. 
151 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { 152 if (length > size()) { 153 return emitError("attempting to parse ", length, " bytes when only ", 154 size(), " remain"); 155 } 156 result = {dataIt, length}; 157 dataIt += length; 158 return success(); 159 } 160 /// Parse a range of bytes of 'length' into the given result, which can be 161 /// assumed to be large enough to hold `length`. 162 LogicalResult parseBytes(size_t length, uint8_t *result) { 163 if (length > size()) { 164 return emitError("attempting to parse ", length, " bytes when only ", 165 size(), " remain"); 166 } 167 memcpy(result, dataIt, length); 168 dataIt += length; 169 return success(); 170 } 171 172 /// Parse an aligned blob of data, where the alignment was encoded alongside 173 /// the data. 174 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, 175 uint64_t &alignment) { 176 uint64_t dataSize; 177 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || 178 failed(alignTo(alignment))) 179 return failure(); 180 return parseBytes(dataSize, data); 181 } 182 183 /// Parse a variable length encoded integer from the byte stream. The first 184 /// encoded byte contains a prefix in the low bits indicating the encoded 185 /// length of the value. This length prefix is a bit sequence of '0's followed 186 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes 187 /// (not including the prefix byte). All remaining bits in the first byte, 188 /// along with all of the bits in additional bytes, provide the value of the 189 /// integer encoded in little-endian order. 190 LogicalResult parseVarInt(uint64_t &result) { 191 // Parse the first byte of the encoding, which contains the length prefix. 192 if (failed(parseByte(result))) 193 return failure(); 194 195 // Handle the overwhelmingly common case where the value is stored in a 196 // single byte. In this case, the first bit is the `1` marker bit. 
197 if (LLVM_LIKELY(result & 1)) { 198 result >>= 1; 199 return success(); 200 } 201 202 // Handle the overwhelming uncommon case where the value required all 8 203 // bytes (i.e. a really really big number). In this case, the marker byte is 204 // all zeros: `00000000`. 205 if (LLVM_UNLIKELY(result == 0)) { 206 llvm::support::ulittle64_t resultLE; 207 if (failed(parseBytes(sizeof(resultLE), 208 reinterpret_cast<uint8_t *>(&resultLE)))) 209 return failure(); 210 result = resultLE; 211 return success(); 212 } 213 return parseMultiByteVarInt(result); 214 } 215 216 /// Parse a signed variable length encoded integer from the byte stream. A 217 /// signed varint is encoded as a normal varint with zigzag encoding applied, 218 /// i.e. the low bit of the value is used to indicate the sign. 219 LogicalResult parseSignedVarInt(uint64_t &result) { 220 if (failed(parseVarInt(result))) 221 return failure(); 222 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) 223 result = (result >> 1) ^ (~(result & 1) + 1); 224 return success(); 225 } 226 227 /// Parse a variable length encoded integer whose low bit is used to encode an 228 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. 229 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { 230 if (failed(parseVarInt(result))) 231 return failure(); 232 flag = result & 1; 233 result >>= 1; 234 return success(); 235 } 236 237 /// Skip the first `length` bytes within the reader. 238 LogicalResult skipBytes(size_t length) { 239 if (length > size()) { 240 return emitError("attempting to skip ", length, " bytes when only ", 241 size(), " remain"); 242 } 243 dataIt += length; 244 return success(); 245 } 246 247 /// Parse a null-terminated string into `result` (without including the NUL 248 /// terminator). 
249 LogicalResult parseNullTerminatedString(StringRef &result) { 250 const char *startIt = (const char *)dataIt; 251 const char *nulIt = (const char *)memchr(startIt, 0, size()); 252 if (!nulIt) 253 return emitError( 254 "malformed null-terminated string, no null character found"); 255 256 result = StringRef(startIt, nulIt - startIt); 257 dataIt = (const uint8_t *)nulIt + 1; 258 return success(); 259 } 260 261 /// Parse a section header, placing the kind of section in `sectionID` and the 262 /// contents of the section in `sectionData`. 263 LogicalResult parseSection(bytecode::Section::ID §ionID, 264 ArrayRef<uint8_t> §ionData) { 265 uint8_t sectionIDAndHasAlignment; 266 uint64_t length; 267 if (failed(parseByte(sectionIDAndHasAlignment)) || 268 failed(parseVarInt(length))) 269 return failure(); 270 271 // Extract the section ID and whether the section is aligned. The high bit 272 // of the ID is the alignment flag. 273 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & 274 0b01111111); 275 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; 276 277 // Check that the section is actually valid before trying to process its 278 // data. 279 if (sectionID >= bytecode::Section::kNumSections) 280 return emitError("invalid section ID: ", unsigned(sectionID)); 281 282 // Process the section alignment if present. 283 if (hasAlignment) { 284 uint64_t alignment; 285 if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) 286 return failure(); 287 } 288 289 // Parse the actual section data. 290 return parseBytes(static_cast<size_t>(length), sectionData); 291 } 292 293 Location getLoc() const { return fileLoc; } 294 295 private: 296 /// Parse a variable length encoded integer from the byte stream. This method 297 /// is a fallback when the number of bytes used to encode the value is greater 298 /// than 1, but less than the max (9). The provided `result` value can be 299 /// assumed to already contain the first byte of the value. 
300 /// NOTE: This method is marked noinline to avoid pessimizing the common case 301 /// of single byte encoding. 302 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) { 303 // Count the number of trailing zeros in the marker byte, this indicates the 304 // number of trailing bytes that are part of the value. We use `uint32_t` 305 // here because we only care about the first byte, and so that be actually 306 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop 307 // implementation). 308 uint32_t numBytes = llvm::countr_zero<uint32_t>(result); 309 assert(numBytes > 0 && numBytes <= 7 && 310 "unexpected number of trailing zeros in varint encoding"); 311 312 // Parse in the remaining bytes of the value. 313 llvm::support::ulittle64_t resultLE(result); 314 if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1))) 315 return failure(); 316 317 // Shift out the low-order bits that were used to mark how the value was 318 // encoded. 319 result = resultLE >> (numBytes + 1); 320 return success(); 321 } 322 323 /// The current data iterator, and an iterator to the end of the buffer. 324 const uint8_t *dataIt, *dataEnd; 325 326 /// A location for the bytecode used to report errors. 327 Location fileLoc; 328 }; 329 } // namespace 330 331 /// Resolve an index into the given entry list. `entry` may either be a 332 /// reference, in which case it is assigned to the corresponding value in 333 /// `entries`, or a pointer, in which case it is assigned to the address of the 334 /// element in `entries`. 335 template <typename RangeT, typename T> 336 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, 337 uint64_t index, T &entry, 338 StringRef entryStr) { 339 if (index >= entries.size()) 340 return reader.emitError("invalid ", entryStr, " index: ", index); 341 342 // If the provided entry is a pointer, resolve to the address of the entry. 
343 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) 344 entry = entries[index]; 345 else 346 entry = &entries[index]; 347 return success(); 348 } 349 350 /// Parse and resolve an index into the given entry list. 351 template <typename RangeT, typename T> 352 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, 353 T &entry, StringRef entryStr) { 354 uint64_t entryIdx; 355 if (failed(reader.parseVarInt(entryIdx))) 356 return failure(); 357 return resolveEntry(reader, entries, entryIdx, entry, entryStr); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // StringSectionReader 362 //===----------------------------------------------------------------------===// 363 364 namespace { 365 /// This class is used to read references to the string section from the 366 /// bytecode. 367 class StringSectionReader { 368 public: 369 /// Initialize the string section reader with the given section data. 370 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); 371 372 /// Parse a shared string from the string section. The shared string is 373 /// encoded using an index to a corresponding string in the string section. 374 LogicalResult parseString(EncodingReader &reader, StringRef &result) { 375 return parseEntry(reader, strings, result, "string"); 376 } 377 378 /// Parse a shared string from the string section. The shared string is 379 /// encoded using an index to a corresponding string in the string section. 380 /// This variant parses a flag compressed with the index. 381 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, 382 bool &flag) { 383 uint64_t entryIdx; 384 if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) 385 return failure(); 386 return parseStringAtIndex(reader, entryIdx, result); 387 } 388 389 /// Parse a shared string from the string section. 
The shared string is 390 /// encoded using an index to a corresponding string in the string section. 391 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, 392 StringRef &result) { 393 return resolveEntry(reader, strings, index, result, "string"); 394 } 395 396 private: 397 /// The table of strings referenced within the bytecode file. 398 SmallVector<StringRef> strings; 399 }; 400 } // namespace 401 402 LogicalResult StringSectionReader::initialize(Location fileLoc, 403 ArrayRef<uint8_t> sectionData) { 404 EncodingReader stringReader(sectionData, fileLoc); 405 406 // Parse the number of strings in the section. 407 uint64_t numStrings; 408 if (failed(stringReader.parseVarInt(numStrings))) 409 return failure(); 410 strings.resize(numStrings); 411 412 // Parse each of the strings. The sizes of the strings are encoded in reverse 413 // order, so that's the order we populate the table. 414 size_t stringDataEndOffset = sectionData.size(); 415 for (StringRef &string : llvm::reverse(strings)) { 416 uint64_t stringSize; 417 if (failed(stringReader.parseVarInt(stringSize))) 418 return failure(); 419 if (stringDataEndOffset < stringSize) { 420 return stringReader.emitError( 421 "string size exceeds the available data size"); 422 } 423 424 // Extract the string from the data, dropping the null character. 425 size_t stringOffset = stringDataEndOffset - stringSize; 426 string = StringRef( 427 reinterpret_cast<const char *>(sectionData.data() + stringOffset), 428 stringSize - 1); 429 stringDataEndOffset = stringOffset; 430 } 431 432 // Check that the only remaining data was for the strings, i.e. the reader 433 // should be at the same offset as the first string. 
434 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { 435 return stringReader.emitError("unexpected trailing data between the " 436 "offsets for strings and their data"); 437 } 438 return success(); 439 } 440 441 //===----------------------------------------------------------------------===// 442 // BytecodeDialect 443 //===----------------------------------------------------------------------===// 444 445 namespace { 446 class DialectReader; 447 448 /// This struct represents a dialect entry within the bytecode. 449 struct BytecodeDialect { 450 /// Load the dialect into the provided context if it hasn't been loaded yet. 451 /// Returns failure if the dialect couldn't be loaded *and* the provided 452 /// context does not allow unregistered dialects. The provided reader is used 453 /// for error emission if necessary. 454 LogicalResult load(DialectReader &reader, MLIRContext *ctx); 455 456 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can 457 /// only be called after `load`. 458 Dialect *getLoadedDialect() const { 459 assert(dialect && 460 "expected `load` to be invoked before `getLoadedDialect`"); 461 return *dialect; 462 } 463 464 /// The loaded dialect entry. This field is std::nullopt if we haven't 465 /// attempted to load, nullptr if we failed to load, otherwise the loaded 466 /// dialect. 467 std::optional<Dialect *> dialect; 468 469 /// The bytecode interface of the dialect, or nullptr if the dialect does not 470 /// implement the bytecode interface. This field should only be checked if the 471 /// `dialect` field is not std::nullopt. 472 const BytecodeDialectInterface *interface = nullptr; 473 474 /// The name of the dialect. 475 StringRef name; 476 477 /// A buffer containing the encoding of the dialect version parsed. 478 ArrayRef<uint8_t> versionBuffer; 479 480 /// Lazy loaded dialect version from the handle above. 
481 std::unique_ptr<DialectVersion> loadedVersion; 482 }; 483 484 /// This struct represents an operation name entry within the bytecode. 485 struct BytecodeOperationName { 486 BytecodeOperationName(BytecodeDialect *dialect, StringRef name, 487 std::optional<bool> wasRegistered) 488 : dialect(dialect), name(name), wasRegistered(wasRegistered) {} 489 490 /// The loaded operation name, or std::nullopt if it hasn't been processed 491 /// yet. 492 std::optional<OperationName> opName; 493 494 /// The dialect that owns this operation name. 495 BytecodeDialect *dialect; 496 497 /// The name of the operation, without the dialect prefix. 498 StringRef name; 499 500 /// Whether this operation was registered when the bytecode was produced. 501 /// This flag is populated when bytecode version >=kNativePropertiesEncoding. 502 std::optional<bool> wasRegistered; 503 }; 504 } // namespace 505 506 /// Parse a single dialect group encoded in the byte stream. 507 static LogicalResult parseDialectGrouping( 508 EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, 509 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { 510 // Parse the dialect and the number of entries in the group. 511 BytecodeDialect *dialect; 512 if (failed(parseEntry(reader, dialects, dialect, "dialect"))) 513 return failure(); 514 uint64_t numEntries; 515 if (failed(reader.parseVarInt(numEntries))) 516 return failure(); 517 518 for (uint64_t i = 0; i < numEntries; ++i) 519 if (failed(entryCallback(dialect))) 520 return failure(); 521 return success(); 522 } 523 524 //===----------------------------------------------------------------------===// 525 // ResourceSectionReader 526 //===----------------------------------------------------------------------===// 527 528 namespace { 529 /// This class is used to read the resource section from the bytecode. 530 class ResourceSectionReader { 531 public: 532 /// Initialize the resource section reader with the given section data. 
533 LogicalResult 534 initialize(Location fileLoc, const ParserConfig &config, 535 MutableArrayRef<BytecodeDialect> dialects, 536 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 537 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 538 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); 539 540 /// Parse a dialect resource handle from the resource section. 541 LogicalResult parseResourceHandle(EncodingReader &reader, 542 AsmDialectResourceHandle &result) { 543 return parseEntry(reader, dialectResources, result, "resource handle"); 544 } 545 546 private: 547 /// The table of dialect resources within the bytecode file. 548 SmallVector<AsmDialectResourceHandle> dialectResources; 549 llvm::StringMap<std::string> dialectResourceHandleRenamingMap; 550 }; 551 552 class ParsedResourceEntry : public AsmParsedResourceEntry { 553 public: 554 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, 555 EncodingReader &reader, StringSectionReader &stringReader, 556 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 557 : key(key), kind(kind), reader(reader), stringReader(stringReader), 558 bufferOwnerRef(bufferOwnerRef) {} 559 ~ParsedResourceEntry() override = default; 560 561 StringRef getKey() const final { return key; } 562 563 InFlightDiagnostic emitError() const final { return reader.emitError(); } 564 565 AsmResourceEntryKind getKind() const final { return kind; } 566 567 FailureOr<bool> parseAsBool() const final { 568 if (kind != AsmResourceEntryKind::Bool) 569 return emitError() << "expected a bool resource entry, but found a " 570 << toString(kind) << " entry instead"; 571 572 bool value; 573 if (failed(reader.parseByte(value))) 574 return failure(); 575 return value; 576 } 577 FailureOr<std::string> parseAsString() const final { 578 if (kind != AsmResourceEntryKind::String) 579 return emitError() << "expected a string resource entry, but found a " 580 << toString(kind) << " entry instead"; 581 582 StringRef string; 583 
if (failed(stringReader.parseString(reader, string))) 584 return failure(); 585 return string.str(); 586 } 587 588 FailureOr<AsmResourceBlob> 589 parseAsBlob(BlobAllocatorFn allocator) const final { 590 if (kind != AsmResourceEntryKind::Blob) 591 return emitError() << "expected a blob resource entry, but found a " 592 << toString(kind) << " entry instead"; 593 594 ArrayRef<uint8_t> data; 595 uint64_t alignment; 596 if (failed(reader.parseBlobAndAlignment(data, alignment))) 597 return failure(); 598 599 // If we have an extendable reference to the buffer owner, we don't need to 600 // allocate a new buffer for the data, and can use the data directly. 601 if (bufferOwnerRef) { 602 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), 603 data.size()); 604 605 // Allocate an unmanager buffer which captures a reference to the owner. 606 // For now we just mark this as immutable, but in the future we should 607 // explore marking this as mutable when desired. 608 return UnmanagedAsmResourceBlob::allocateWithAlign( 609 charData, alignment, 610 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); 611 } 612 613 // Allocate memory for the blob using the provided allocator and copy the 614 // data into it. 
615 AsmResourceBlob blob = allocator(data.size(), alignment); 616 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && 617 blob.isMutable() && 618 "blob allocator did not return a properly aligned address"); 619 memcpy(blob.getMutableData().data(), data.data(), data.size()); 620 return blob; 621 } 622 623 private: 624 StringRef key; 625 AsmResourceEntryKind kind; 626 EncodingReader &reader; 627 StringSectionReader &stringReader; 628 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 629 }; 630 } // namespace 631 632 template <typename T> 633 static LogicalResult 634 parseResourceGroup(Location fileLoc, bool allowEmpty, 635 EncodingReader &offsetReader, EncodingReader &resourceReader, 636 StringSectionReader &stringReader, T *handler, 637 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, 638 function_ref<StringRef(StringRef)> remapKey = {}, 639 function_ref<LogicalResult(StringRef)> processKeyFn = {}) { 640 uint64_t numResources; 641 if (failed(offsetReader.parseVarInt(numResources))) 642 return failure(); 643 644 for (uint64_t i = 0; i < numResources; ++i) { 645 StringRef key; 646 AsmResourceEntryKind kind; 647 uint64_t resourceOffset; 648 ArrayRef<uint8_t> data; 649 if (failed(stringReader.parseString(offsetReader, key)) || 650 failed(offsetReader.parseVarInt(resourceOffset)) || 651 failed(offsetReader.parseByte(kind)) || 652 failed(resourceReader.parseBytes(resourceOffset, data))) 653 return failure(); 654 655 // Process the resource key. 656 if ((processKeyFn && failed(processKeyFn(key)))) 657 return failure(); 658 659 // If the resource data is empty and we allow it, don't error out when 660 // parsing below, just skip it. 661 if (allowEmpty && data.empty()) 662 continue; 663 664 // Ignore the entry if we don't have a valid handler. 665 if (!handler) 666 continue; 667 668 // Otherwise, parse the resource value. 
669 EncodingReader entryReader(data, fileLoc); 670 key = remapKey(key); 671 ParsedResourceEntry entry(key, kind, entryReader, stringReader, 672 bufferOwnerRef); 673 if (failed(handler->parseResource(entry))) 674 return failure(); 675 if (!entryReader.empty()) { 676 return entryReader.emitError( 677 "unexpected trailing bytes in resource entry '", key, "'"); 678 } 679 } 680 return success(); 681 } 682 683 LogicalResult ResourceSectionReader::initialize( 684 Location fileLoc, const ParserConfig &config, 685 MutableArrayRef<BytecodeDialect> dialects, 686 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 687 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 688 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 689 EncodingReader resourceReader(sectionData, fileLoc); 690 EncodingReader offsetReader(offsetSectionData, fileLoc); 691 692 // Read the number of external resource providers. 693 uint64_t numExternalResourceGroups; 694 if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) 695 return failure(); 696 697 // Utility functor that dispatches to `parseResourceGroup`, but implicitly 698 // provides most of the arguments. 699 auto parseGroup = [&](auto *handler, bool allowEmpty = false, 700 function_ref<LogicalResult(StringRef)> keyFn = {}) { 701 auto resolveKey = [&](StringRef key) -> StringRef { 702 auto it = dialectResourceHandleRenamingMap.find(key); 703 if (it == dialectResourceHandleRenamingMap.end()) 704 return ""; 705 return it->second; 706 }; 707 708 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, 709 stringReader, handler, bufferOwnerRef, resolveKey, 710 keyFn); 711 }; 712 713 // Read the external resources from the bytecode. 714 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { 715 StringRef key; 716 if (failed(stringReader.parseString(offsetReader, key))) 717 return failure(); 718 719 // Get the handler for these resources. 
720 // TODO: Should we require handling external resources in some scenarios? 721 AsmResourceParser *handler = config.getResourceParser(key); 722 if (!handler) { 723 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key 724 << "'"; 725 } 726 727 if (failed(parseGroup(handler))) 728 return failure(); 729 } 730 731 // Read the dialect resources from the bytecode. 732 MLIRContext *ctx = fileLoc->getContext(); 733 while (!offsetReader.empty()) { 734 BytecodeDialect *dialect; 735 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || 736 failed(dialect->load(dialectReader, ctx))) 737 return failure(); 738 Dialect *loadedDialect = dialect->getLoadedDialect(); 739 if (!loadedDialect) { 740 return resourceReader.emitError() 741 << "dialect '" << dialect->name << "' is unknown"; 742 } 743 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); 744 if (!handler) { 745 return resourceReader.emitError() 746 << "unexpected resources for dialect '" << dialect->name << "'"; 747 } 748 749 // Ensure that each resource is declared before being processed. 750 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { 751 FailureOr<AsmDialectResourceHandle> handle = 752 handler->declareResource(key); 753 if (failed(handle)) { 754 return resourceReader.emitError() 755 << "unknown 'resource' key '" << key << "' for dialect '" 756 << dialect->name << "'"; 757 } 758 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); 759 dialectResources.push_back(*handle); 760 return success(); 761 }; 762 763 // Parse the resources for this dialect. We allow empty resources because we 764 // just treat these as declarations. 
765 if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) 766 return failure(); 767 } 768 769 return success(); 770 } 771 772 //===----------------------------------------------------------------------===// 773 // Attribute/Type Reader 774 //===----------------------------------------------------------------------===// 775 776 namespace { 777 /// This class provides support for reading attribute and type entries from the 778 /// bytecode. Attribute and Type entries are read lazily on demand, so we use 779 /// this reader to manage when to actually parse them from the bytecode. 780 class AttrTypeReader { 781 /// This class represents a single attribute or type entry. 782 template <typename T> 783 struct Entry { 784 /// The entry, or null if it hasn't been resolved yet. 785 T entry = {}; 786 /// The parent dialect of this entry. 787 BytecodeDialect *dialect = nullptr; 788 /// A flag indicating if the entry was encoded using a custom encoding, 789 /// instead of using the textual assembly format. 790 bool hasCustomEncoding = false; 791 /// The raw data of this entry in the bytecode. 792 ArrayRef<uint8_t> data; 793 }; 794 using AttrEntry = Entry<Attribute>; 795 using TypeEntry = Entry<Type>; 796 797 public: 798 AttrTypeReader(StringSectionReader &stringReader, 799 ResourceSectionReader &resourceReader, Location fileLoc, 800 uint64_t &bytecodeVersion) 801 : stringReader(stringReader), resourceReader(resourceReader), 802 fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {} 803 804 /// Initialize the attribute and type information within the reader. 805 LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, 806 ArrayRef<uint8_t> sectionData, 807 ArrayRef<uint8_t> offsetSectionData); 808 809 /// Resolve the attribute or type at the given index. Returns nullptr on 810 /// failure. 
811 Attribute resolveAttribute(size_t index) { 812 return resolveEntry(attributes, index, "Attribute"); 813 } 814 Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } 815 816 /// Parse a reference to an attribute or type using the given reader. 817 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { 818 uint64_t attrIdx; 819 if (failed(reader.parseVarInt(attrIdx))) 820 return failure(); 821 result = resolveAttribute(attrIdx); 822 return success(!!result); 823 } 824 LogicalResult parseOptionalAttribute(EncodingReader &reader, 825 Attribute &result) { 826 uint64_t attrIdx; 827 bool flag; 828 if (failed(reader.parseVarIntWithFlag(attrIdx, flag))) 829 return failure(); 830 if (!flag) 831 return success(); 832 result = resolveAttribute(attrIdx); 833 return success(!!result); 834 } 835 836 LogicalResult parseType(EncodingReader &reader, Type &result) { 837 uint64_t typeIdx; 838 if (failed(reader.parseVarInt(typeIdx))) 839 return failure(); 840 result = resolveType(typeIdx); 841 return success(!!result); 842 } 843 844 template <typename T> 845 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 846 Attribute baseResult; 847 if (failed(parseAttribute(reader, baseResult))) 848 return failure(); 849 if ((result = dyn_cast<T>(baseResult))) 850 return success(); 851 return reader.emitError("expected attribute of type: ", 852 llvm::getTypeName<T>(), ", but got: ", baseResult); 853 } 854 855 private: 856 /// Resolve the given entry at `index`. 857 template <typename T> 858 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 859 StringRef entryType); 860 861 /// Parse an entry using the given reader that was encoded using the textual 862 /// assembly format. 863 template <typename T> 864 LogicalResult parseAsmEntry(T &result, EncodingReader &reader, 865 StringRef entryType); 866 867 /// Parse an entry using the given reader that was encoded using a custom 868 /// bytecode format. 
869 template <typename T> 870 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, 871 StringRef entryType); 872 873 /// The string section reader used to resolve string references when parsing 874 /// custom encoded attribute/type entries. 875 StringSectionReader &stringReader; 876 877 /// The resource section reader used to resolve resource references when 878 /// parsing custom encoded attribute/type entries. 879 ResourceSectionReader &resourceReader; 880 881 /// The set of attribute and type entries. 882 SmallVector<AttrEntry> attributes; 883 SmallVector<TypeEntry> types; 884 885 /// A location used for error emission. 886 Location fileLoc; 887 888 /// Current bytecode version being used. 889 uint64_t &bytecodeVersion; 890 }; 891 892 class DialectReader : public DialectBytecodeReader { 893 public: 894 DialectReader(AttrTypeReader &attrTypeReader, 895 StringSectionReader &stringReader, 896 ResourceSectionReader &resourceReader, EncodingReader &reader, 897 uint64_t &bytecodeVersion) 898 : attrTypeReader(attrTypeReader), stringReader(stringReader), 899 resourceReader(resourceReader), reader(reader), 900 bytecodeVersion(bytecodeVersion) {} 901 902 InFlightDiagnostic emitError(const Twine &msg) override { 903 return reader.emitError(msg); 904 } 905 906 uint64_t getBytecodeVersion() const override { return bytecodeVersion; } 907 908 DialectReader withEncodingReader(EncodingReader &encReader) { 909 return DialectReader(attrTypeReader, stringReader, resourceReader, 910 encReader, bytecodeVersion); 911 } 912 913 Location getLoc() const { return reader.getLoc(); } 914 915 //===--------------------------------------------------------------------===// 916 // IR 917 //===--------------------------------------------------------------------===// 918 919 LogicalResult readAttribute(Attribute &result) override { 920 return attrTypeReader.parseAttribute(reader, result); 921 } 922 LogicalResult readOptionalAttribute(Attribute &result) override { 923 return 
attrTypeReader.parseOptionalAttribute(reader, result); 924 } 925 LogicalResult readType(Type &result) override { 926 return attrTypeReader.parseType(reader, result); 927 } 928 929 FailureOr<AsmDialectResourceHandle> readResourceHandle() override { 930 AsmDialectResourceHandle handle; 931 if (failed(resourceReader.parseResourceHandle(reader, handle))) 932 return failure(); 933 return handle; 934 } 935 936 //===--------------------------------------------------------------------===// 937 // Primitives 938 //===--------------------------------------------------------------------===// 939 940 LogicalResult readVarInt(uint64_t &result) override { 941 return reader.parseVarInt(result); 942 } 943 944 LogicalResult readSignedVarInt(int64_t &result) override { 945 uint64_t unsignedResult; 946 if (failed(reader.parseSignedVarInt(unsignedResult))) 947 return failure(); 948 result = static_cast<int64_t>(unsignedResult); 949 return success(); 950 } 951 952 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { 953 // Small values are encoded using a single byte. 954 if (bitWidth <= 8) { 955 uint8_t value; 956 if (failed(reader.parseByte(value))) 957 return failure(); 958 return APInt(bitWidth, value); 959 } 960 961 // Large values up to 64 bits are encoded using a single varint. 962 if (bitWidth <= 64) { 963 uint64_t value; 964 if (failed(reader.parseSignedVarInt(value))) 965 return failure(); 966 return APInt(bitWidth, value); 967 } 968 969 // Otherwise, for really big values we encode the array of active words in 970 // the value. 
971 uint64_t numActiveWords; 972 if (failed(reader.parseVarInt(numActiveWords))) 973 return failure(); 974 SmallVector<uint64_t, 4> words(numActiveWords); 975 for (uint64_t i = 0; i < numActiveWords; ++i) 976 if (failed(reader.parseSignedVarInt(words[i]))) 977 return failure(); 978 return APInt(bitWidth, words); 979 } 980 981 FailureOr<APFloat> 982 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { 983 FailureOr<APInt> intVal = 984 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); 985 if (failed(intVal)) 986 return failure(); 987 return APFloat(semantics, *intVal); 988 } 989 990 LogicalResult readString(StringRef &result) override { 991 return stringReader.parseString(reader, result); 992 } 993 994 LogicalResult readBlob(ArrayRef<char> &result) override { 995 uint64_t dataSize; 996 ArrayRef<uint8_t> data; 997 if (failed(reader.parseVarInt(dataSize)) || 998 failed(reader.parseBytes(dataSize, data))) 999 return failure(); 1000 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), 1001 data.size()); 1002 return success(); 1003 } 1004 1005 LogicalResult readBool(bool &result) override { 1006 return reader.parseByte(result); 1007 } 1008 1009 private: 1010 AttrTypeReader &attrTypeReader; 1011 StringSectionReader &stringReader; 1012 ResourceSectionReader &resourceReader; 1013 EncodingReader &reader; 1014 uint64_t &bytecodeVersion; 1015 }; 1016 1017 /// Wraps the properties section and handles reading properties out of it. 1018 class PropertiesSectionReader { 1019 public: 1020 /// Initialize the properties section reader with the given section data. 1021 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { 1022 if (sectionData.empty()) 1023 return success(); 1024 EncodingReader propReader(sectionData, fileLoc); 1025 uint64_t count; 1026 if (failed(propReader.parseVarInt(count))) 1027 return failure(); 1028 // Parse the raw properties buffer. 
1029 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) 1030 return failure(); 1031 1032 EncodingReader offsetsReader(propertiesBuffers, fileLoc); 1033 offsetTable.reserve(count); 1034 for (auto idx : llvm::seq<int64_t>(0, count)) { 1035 (void)idx; 1036 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); 1037 ArrayRef<uint8_t> rawProperties; 1038 uint64_t dataSize; 1039 if (failed(offsetsReader.parseVarInt(dataSize)) || 1040 failed(offsetsReader.parseBytes(dataSize, rawProperties))) 1041 return failure(); 1042 } 1043 if (!offsetsReader.empty()) 1044 return offsetsReader.emitError() 1045 << "Broken properties section: didn't exhaust the offsets table"; 1046 return success(); 1047 } 1048 1049 LogicalResult read(Location fileLoc, DialectReader &dialectReader, 1050 OperationName *opName, OperationState &opState) { 1051 uint64_t propertiesIdx; 1052 if (failed(dialectReader.readVarInt(propertiesIdx))) 1053 return failure(); 1054 if (propertiesIdx >= offsetTable.size()) 1055 return dialectReader.emitError("Properties idx out-of-bound for ") 1056 << opName->getStringRef(); 1057 size_t propertiesOffset = offsetTable[propertiesIdx]; 1058 if (propertiesIdx >= propertiesBuffers.size()) 1059 return dialectReader.emitError("Properties offset out-of-bound for ") 1060 << opName->getStringRef(); 1061 1062 // Acquire the sub-buffer that represent the requested properties. 1063 ArrayRef<char> rawProperties; 1064 { 1065 // "Seek" to the requested offset by getting a new reader with the right 1066 // sub-buffer. 1067 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), 1068 fileLoc); 1069 // Properties are stored as a sequence of {size + raw_data}. 1070 if (failed( 1071 dialectReader.withEncodingReader(reader).readBlob(rawProperties))) 1072 return failure(); 1073 } 1074 // Setup a new reader to read from the `rawProperties` sub-buffer. 
1075 EncodingReader reader( 1076 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); 1077 DialectReader propReader = dialectReader.withEncodingReader(reader); 1078 1079 auto *iface = opName->getInterface<BytecodeOpInterface>(); 1080 if (iface) 1081 return iface->readProperties(propReader, opState); 1082 if (opName->isRegistered()) 1083 return propReader.emitError( 1084 "has properties but missing BytecodeOpInterface for ") 1085 << opName->getStringRef(); 1086 // Unregistered op are storing properties as an attribute. 1087 return propReader.readAttribute(opState.propertiesAttr); 1088 } 1089 1090 private: 1091 /// The properties buffer referenced within the bytecode file. 1092 ArrayRef<uint8_t> propertiesBuffers; 1093 1094 /// Table of offset in the buffer above. 1095 SmallVector<int64_t> offsetTable; 1096 }; 1097 } // namespace 1098 1099 LogicalResult 1100 AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, 1101 ArrayRef<uint8_t> sectionData, 1102 ArrayRef<uint8_t> offsetSectionData) { 1103 EncodingReader offsetReader(offsetSectionData, fileLoc); 1104 1105 // Parse the number of attribute and type entries. 1106 uint64_t numAttributes, numTypes; 1107 if (failed(offsetReader.parseVarInt(numAttributes)) || 1108 failed(offsetReader.parseVarInt(numTypes))) 1109 return failure(); 1110 attributes.resize(numAttributes); 1111 types.resize(numTypes); 1112 1113 // A functor used to accumulate the offsets for the entries in the given 1114 // range. 1115 uint64_t currentOffset = 0; 1116 auto parseEntries = [&](auto &&range) { 1117 size_t currentIndex = 0, endIndex = range.size(); 1118 1119 // Parse an individual entry. 1120 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { 1121 auto &entry = range[currentIndex++]; 1122 1123 uint64_t entrySize; 1124 if (failed(offsetReader.parseVarIntWithFlag(entrySize, 1125 entry.hasCustomEncoding))) 1126 return failure(); 1127 1128 // Verify that the offset is actually valid. 
1129 if (currentOffset + entrySize > sectionData.size()) { 1130 return offsetReader.emitError( 1131 "Attribute or Type entry offset points past the end of section"); 1132 } 1133 1134 entry.data = sectionData.slice(currentOffset, entrySize); 1135 entry.dialect = dialect; 1136 currentOffset += entrySize; 1137 return success(); 1138 }; 1139 while (currentIndex != endIndex) 1140 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) 1141 return failure(); 1142 return success(); 1143 }; 1144 1145 // Process each of the attributes, and then the types. 1146 if (failed(parseEntries(attributes)) || failed(parseEntries(types))) 1147 return failure(); 1148 1149 // Ensure that we read everything from the section. 1150 if (!offsetReader.empty()) { 1151 return offsetReader.emitError( 1152 "unexpected trailing data in the Attribute/Type offset section"); 1153 } 1154 return success(); 1155 } 1156 1157 template <typename T> 1158 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 1159 StringRef entryType) { 1160 if (index >= entries.size()) { 1161 emitError(fileLoc) << "invalid " << entryType << " index: " << index; 1162 return {}; 1163 } 1164 1165 // If the entry has already been resolved, there is nothing left to do. 1166 Entry<T> &entry = entries[index]; 1167 if (entry.entry) 1168 return entry.entry; 1169 1170 // Parse the entry. 1171 EncodingReader reader(entry.data, fileLoc); 1172 1173 // Parse based on how the entry was encoded. 
1174 if (entry.hasCustomEncoding) { 1175 if (failed(parseCustomEntry(entry, reader, entryType))) 1176 return T(); 1177 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { 1178 return T(); 1179 } 1180 1181 if (!reader.empty()) { 1182 reader.emitError("unexpected trailing bytes after " + entryType + " entry"); 1183 return T(); 1184 } 1185 return entry.entry; 1186 } 1187 1188 template <typename T> 1189 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, 1190 StringRef entryType) { 1191 StringRef asmStr; 1192 if (failed(reader.parseNullTerminatedString(asmStr))) 1193 return failure(); 1194 1195 // Invoke the MLIR assembly parser to parse the entry text. 1196 size_t numRead = 0; 1197 MLIRContext *context = fileLoc->getContext(); 1198 if constexpr (std::is_same_v<T, Type>) 1199 result = 1200 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); 1201 else 1202 result = ::parseAttribute(asmStr, context, Type(), &numRead, 1203 /*isKnownNullTerminated=*/true); 1204 if (!result) 1205 return failure(); 1206 1207 // Ensure there weren't dangling characters after the entry. 1208 if (numRead != asmStr.size()) { 1209 return reader.emitError("trailing characters found after ", entryType, 1210 " assembly format: ", asmStr.drop_front(numRead)); 1211 } 1212 return success(); 1213 } 1214 1215 template <typename T> 1216 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, 1217 EncodingReader &reader, 1218 StringRef entryType) { 1219 DialectReader dialectReader(*this, stringReader, resourceReader, reader, 1220 bytecodeVersion); 1221 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) 1222 return failure(); 1223 // Ensure that the dialect implements the bytecode interface. 1224 if (!entry.dialect->interface) { 1225 return reader.emitError("dialect '", entry.dialect->name, 1226 "' does not implement the bytecode interface"); 1227 } 1228 1229 // Ask the dialect to parse the entry. 
If the dialect is versioned, parse 1230 // using the versioned encoding readers. 1231 if (entry.dialect->loadedVersion.get()) { 1232 if constexpr (std::is_same_v<T, Type>) 1233 entry.entry = entry.dialect->interface->readType( 1234 dialectReader, *entry.dialect->loadedVersion); 1235 else 1236 entry.entry = entry.dialect->interface->readAttribute( 1237 dialectReader, *entry.dialect->loadedVersion); 1238 1239 } else { 1240 if constexpr (std::is_same_v<T, Type>) 1241 entry.entry = entry.dialect->interface->readType(dialectReader); 1242 else 1243 entry.entry = entry.dialect->interface->readAttribute(dialectReader); 1244 } 1245 return success(!!entry.entry); 1246 } 1247 1248 //===----------------------------------------------------------------------===// 1249 // Bytecode Reader 1250 //===----------------------------------------------------------------------===// 1251 1252 /// This class is used to read a bytecode buffer and translate it into MLIR. 1253 class mlir::BytecodeReader::Impl { 1254 struct RegionReadState; 1255 using LazyLoadableOpsInfo = 1256 std::list<std::pair<Operation *, RegionReadState>>; 1257 using LazyLoadableOpsMap = 1258 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; 1259 1260 public: 1261 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, 1262 llvm::MemoryBufferRef buffer, 1263 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 1264 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), 1265 attrTypeReader(stringReader, resourceReader, fileLoc, version), 1266 // Use the builtin unrealized conversion cast operation to represent 1267 // forward references to values that aren't yet defined. 1268 forwardRefOpState(UnknownLoc::get(config.getContext()), 1269 "builtin.unrealized_conversion_cast", ValueRange(), 1270 NoneType::get(config.getContext())), 1271 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} 1272 1273 /// Read the bytecode defined within `buffer` into the given block. 
1274 LogicalResult read(Block *block, 1275 llvm::function_ref<bool(Operation *)> lazyOps); 1276 1277 /// Return the number of ops that haven't been materialized yet. 1278 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } 1279 1280 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } 1281 1282 /// Materialize the provided operation, invoke the lazyOpsCallback on every 1283 /// newly found lazy operation. 1284 LogicalResult 1285 materialize(Operation *op, 1286 llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1287 this->lazyOpsCallback = lazyOpsCallback; 1288 auto resetlazyOpsCallback = 1289 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1290 auto it = lazyLoadableOpsMap.find(op); 1291 assert(it != lazyLoadableOpsMap.end() && 1292 "materialize called on non-materializable op"); 1293 return materialize(it); 1294 } 1295 1296 /// Materialize all operations. 1297 LogicalResult materializeAll() { 1298 while (!lazyLoadableOpsMap.empty()) { 1299 if (failed(materialize(lazyLoadableOpsMap.begin()))) 1300 return failure(); 1301 } 1302 return success(); 1303 } 1304 1305 /// Finalize the lazy-loading by calling back with every op that hasn't been 1306 /// materialized to let the client decide if the op should be deleted or 1307 /// materialized. The op is materialized if the callback returns true, deleted 1308 /// otherwise. 
1309 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { 1310 while (!lazyLoadableOps.empty()) { 1311 Operation *op = lazyLoadableOps.begin()->first; 1312 if (shouldMaterialize(op)) { 1313 if (failed(materialize(lazyLoadableOpsMap.find(op)))) 1314 return failure(); 1315 continue; 1316 } 1317 op->dropAllReferences(); 1318 op->erase(); 1319 lazyLoadableOps.pop_front(); 1320 lazyLoadableOpsMap.erase(op); 1321 } 1322 return success(); 1323 } 1324 1325 private: 1326 LogicalResult materialize(LazyLoadableOpsMap::iterator it) { 1327 assert(it != lazyLoadableOpsMap.end() && 1328 "materialize called on non-materializable op"); 1329 valueScopes.emplace_back(); 1330 std::vector<RegionReadState> regionStack; 1331 regionStack.push_back(std::move(it->getSecond()->second)); 1332 lazyLoadableOps.erase(it->getSecond()); 1333 lazyLoadableOpsMap.erase(it); 1334 1335 while (!regionStack.empty()) 1336 if (failed(parseRegions(regionStack, regionStack.back()))) 1337 return failure(); 1338 return success(); 1339 } 1340 1341 /// Return the context for this config. 1342 MLIRContext *getContext() const { return config.getContext(); } 1343 1344 /// Parse the bytecode version. 1345 LogicalResult parseVersion(EncodingReader &reader); 1346 1347 //===--------------------------------------------------------------------===// 1348 // Dialect Section 1349 1350 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); 1351 1352 /// Parse an operation name reference using the given reader, and set the 1353 /// `wasRegistered` flag that indicates if the bytecode was produced by a 1354 /// context where opName was registered. 1355 FailureOr<OperationName> parseOpName(EncodingReader &reader, 1356 std::optional<bool> &wasRegistered); 1357 1358 //===--------------------------------------------------------------------===// 1359 // Attribute/Type Section 1360 1361 /// Parse an attribute or type using the given reader. 
1362 template <typename T> 1363 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 1364 return attrTypeReader.parseAttribute(reader, result); 1365 } 1366 LogicalResult parseType(EncodingReader &reader, Type &result) { 1367 return attrTypeReader.parseType(reader, result); 1368 } 1369 1370 //===--------------------------------------------------------------------===// 1371 // Resource Section 1372 1373 LogicalResult 1374 parseResourceSection(EncodingReader &reader, 1375 std::optional<ArrayRef<uint8_t>> resourceData, 1376 std::optional<ArrayRef<uint8_t>> resourceOffsetData); 1377 1378 //===--------------------------------------------------------------------===// 1379 // IR Section 1380 1381 /// This struct represents the current read state of a range of regions. This 1382 /// struct is used to enable iterative parsing of regions. 1383 struct RegionReadState { 1384 RegionReadState(Operation *op, EncodingReader *reader, 1385 bool isIsolatedFromAbove) 1386 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} 1387 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, 1388 bool isIsolatedFromAbove) 1389 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), 1390 isIsolatedFromAbove(isIsolatedFromAbove) {} 1391 1392 /// The current regions being read. 1393 MutableArrayRef<Region>::iterator curRegion, endRegion; 1394 /// This is the reader to use for this region, this pointer is pointing to 1395 /// the parent region reader unless the current region is IsolatedFromAbove, 1396 /// in which case the pointer is pointing to the `owningReader` which is a 1397 /// section dedicated to the current region. 1398 EncodingReader *reader; 1399 std::unique_ptr<EncodingReader> owningReader; 1400 1401 /// The number of values defined immediately within this region. 1402 unsigned numValues = 0; 1403 1404 /// The current blocks of the region being read. 
1405 SmallVector<Block *> curBlocks; 1406 Region::iterator curBlock = {}; 1407 1408 /// The number of operations remaining to be read from the current block 1409 /// being read. 1410 uint64_t numOpsRemaining = 0; 1411 1412 /// A flag indicating if the regions being read are isolated from above. 1413 bool isIsolatedFromAbove = false; 1414 }; 1415 1416 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); 1417 LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, 1418 RegionReadState &readState); 1419 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, 1420 RegionReadState &readState, 1421 bool &isIsolatedFromAbove); 1422 1423 LogicalResult parseRegion(RegionReadState &readState); 1424 LogicalResult parseBlockHeader(EncodingReader &reader, 1425 RegionReadState &readState); 1426 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); 1427 1428 //===--------------------------------------------------------------------===// 1429 // Value Processing 1430 1431 /// Parse an operand reference using the given reader. Returns nullptr in the 1432 /// case of failure. 1433 Value parseOperand(EncodingReader &reader); 1434 1435 /// Sequentially define the given value range. 1436 LogicalResult defineValues(EncodingReader &reader, ValueRange values); 1437 1438 /// Create a value to use for a forward reference. 1439 Value createForwardRef(); 1440 1441 //===--------------------------------------------------------------------===// 1442 // Use-list order helpers 1443 1444 /// This struct is a simple storage that contains information required to 1445 /// reorder the use-list of a value with respect to the pre-order traversal 1446 /// ordering. 
1447 struct UseListOrderStorage { 1448 UseListOrderStorage(bool isIndexPairEncoding, 1449 SmallVector<unsigned, 4> &&indices) 1450 : indices(std::move(indices)), 1451 isIndexPairEncoding(isIndexPairEncoding){}; 1452 /// The vector containing the information required to reorder the 1453 /// use-list of a value. 1454 SmallVector<unsigned, 4> indices; 1455 1456 /// Whether indices represent a pair of type `(src, dst)` or it is a direct 1457 /// indexing, such as `dst = order[src]`. 1458 bool isIndexPairEncoding; 1459 }; 1460 1461 /// Parse use-list order from bytecode for a range of values if available. The 1462 /// range is expected to be either a block argument or an op result range. On 1463 /// success, return a map of the position in the range and the use-list order 1464 /// encoding. The function assumes to know the size of the range it is 1465 /// processing. 1466 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; 1467 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, 1468 uint64_t rangeSize); 1469 1470 /// Shuffle the use-chain according to the order parsed. 1471 LogicalResult sortUseListOrder(Value value); 1472 1473 /// Recursively visit all the values defined within topLevelOp and sort the 1474 /// use-list orders according to the indices parsed. 1475 LogicalResult processUseLists(Operation *topLevelOp); 1476 1477 //===--------------------------------------------------------------------===// 1478 // Fields 1479 1480 /// This class represents a single value scope, in which a value scope is 1481 /// delimited by isolated from above regions. 1482 struct ValueScope { 1483 /// Push a new region state onto this scope, reserving enough values for 1484 /// those defined within the current region of the provided state. 
1485 void push(RegionReadState &readState) { 1486 nextValueIDs.push_back(values.size()); 1487 values.resize(values.size() + readState.numValues); 1488 } 1489 1490 /// Pop the values defined for the current region within the provided region 1491 /// state. 1492 void pop(RegionReadState &readState) { 1493 values.resize(values.size() - readState.numValues); 1494 nextValueIDs.pop_back(); 1495 } 1496 1497 /// The set of values defined in this scope. 1498 std::vector<Value> values; 1499 1500 /// The ID for the next defined value for each region current being 1501 /// processed in this scope. 1502 SmallVector<unsigned, 4> nextValueIDs; 1503 }; 1504 1505 /// The configuration of the parser. 1506 const ParserConfig &config; 1507 1508 /// A location to use when emitting errors. 1509 Location fileLoc; 1510 1511 /// Flag that indicates if lazyloading is enabled. 1512 bool lazyLoading; 1513 1514 /// Keep track of operations that have been lazy loaded (their regions haven't 1515 /// been materialized), along with the `RegionReadState` that allows to 1516 /// lazy-load the regions nested under the operation. 1517 LazyLoadableOpsInfo lazyLoadableOps; 1518 LazyLoadableOpsMap lazyLoadableOpsMap; 1519 llvm::function_ref<bool(Operation *)> lazyOpsCallback; 1520 1521 /// The reader used to process attribute and types within the bytecode. 1522 AttrTypeReader attrTypeReader; 1523 1524 /// The version of the bytecode being read. 1525 uint64_t version = 0; 1526 1527 /// The producer of the bytecode being read. 1528 StringRef producer; 1529 1530 /// The table of IR units referenced within the bytecode file. 1531 SmallVector<BytecodeDialect> dialects; 1532 SmallVector<BytecodeOperationName> opNames; 1533 1534 /// The reader used to process resources within the bytecode. 1535 ResourceSectionReader resourceReader; 1536 1537 /// Worklist of values with custom use-list orders to process before the end 1538 /// of the parsing. 
1539 DenseMap<void *, UseListOrderStorage> valueToUseListMap; 1540 1541 /// The table of strings referenced within the bytecode file. 1542 StringSectionReader stringReader; 1543 1544 /// The table of properties referenced by the operation in the bytecode file. 1545 PropertiesSectionReader propertiesReader; 1546 1547 /// The current set of available IR value scopes. 1548 std::vector<ValueScope> valueScopes; 1549 1550 /// The global pre-order operation ordering. 1551 DenseMap<Operation *, unsigned> operationIDs; 1552 1553 /// A block containing the set of operations defined to create forward 1554 /// references. 1555 Block forwardRefOps; 1556 1557 /// A block containing previously created, and no longer used, forward 1558 /// reference operations. 1559 Block openForwardRefOps; 1560 1561 /// An operation state used when instantiating forward references. 1562 OperationState forwardRefOpState; 1563 1564 /// Reference to the input buffer. 1565 llvm::MemoryBufferRef buffer; 1566 1567 /// The optional owning source manager, which when present may be used to 1568 /// extend the lifetime of the input buffer. 1569 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 1570 }; 1571 1572 LogicalResult BytecodeReader::Impl::read( 1573 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1574 EncodingReader reader(buffer.getBuffer(), fileLoc); 1575 this->lazyOpsCallback = lazyOpsCallback; 1576 auto resetlazyOpsCallback = 1577 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1578 1579 // Skip over the bytecode header, this should have already been checked. 1580 if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) 1581 return failure(); 1582 // Parse the bytecode version and producer. 1583 if (failed(parseVersion(reader)) || 1584 failed(reader.parseNullTerminatedString(producer))) 1585 return failure(); 1586 1587 // Add a diagnostic handler that attaches a note that includes the original 1588 // producer of the bytecode. 
1589 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { 1590 diag.attachNote() << "in bytecode version " << version 1591 << " produced by: " << producer; 1592 return failure(); 1593 }); 1594 1595 // Parse the raw data for each of the top-level sections of the bytecode. 1596 std::optional<ArrayRef<uint8_t>> 1597 sectionDatas[bytecode::Section::kNumSections]; 1598 while (!reader.empty()) { 1599 // Read the next section from the bytecode. 1600 bytecode::Section::ID sectionID; 1601 ArrayRef<uint8_t> sectionData; 1602 if (failed(reader.parseSection(sectionID, sectionData))) 1603 return failure(); 1604 1605 // Check for duplicate sections, we only expect one instance of each. 1606 if (sectionDatas[sectionID]) { 1607 return reader.emitError("duplicate top-level section: ", 1608 ::toString(sectionID)); 1609 } 1610 sectionDatas[sectionID] = sectionData; 1611 } 1612 // Check that all of the required sections were found. 1613 for (int i = 0; i < bytecode::Section::kNumSections; ++i) { 1614 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); 1615 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { 1616 return reader.emitError("missing data for top-level section: ", 1617 ::toString(sectionID)); 1618 } 1619 } 1620 1621 // Process the string section first. 1622 if (failed(stringReader.initialize( 1623 fileLoc, *sectionDatas[bytecode::Section::kString]))) 1624 return failure(); 1625 1626 // Process the properties section. 1627 if (sectionDatas[bytecode::Section::kProperties] && 1628 failed(propertiesReader.initialize( 1629 fileLoc, *sectionDatas[bytecode::Section::kProperties]))) 1630 return failure(); 1631 1632 // Process the dialect section. 1633 if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) 1634 return failure(); 1635 1636 // Process the resource section if present. 
1637 if (failed(parseResourceSection( 1638 reader, sectionDatas[bytecode::Section::kResource], 1639 sectionDatas[bytecode::Section::kResourceOffset]))) 1640 return failure(); 1641 1642 // Process the attribute and type section. 1643 if (failed(attrTypeReader.initialize( 1644 dialects, *sectionDatas[bytecode::Section::kAttrType], 1645 *sectionDatas[bytecode::Section::kAttrTypeOffset]))) 1646 return failure(); 1647 1648 // Finally, process the IR section. 1649 return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); 1650 } 1651 1652 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { 1653 if (failed(reader.parseVarInt(version))) 1654 return failure(); 1655 1656 // Validate the bytecode version. 1657 uint64_t currentVersion = bytecode::kVersion; 1658 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; 1659 if (version < minSupportedVersion) { 1660 return reader.emitError("bytecode version ", version, 1661 " is older than the current version of ", 1662 currentVersion, ", and upgrade is not supported"); 1663 } 1664 if (version > currentVersion) { 1665 return reader.emitError("bytecode version ", version, 1666 " is newer than the current version ", 1667 currentVersion); 1668 } 1669 // Override any request to lazy-load if the bytecode version is too old. 1670 if (version < bytecode::kLazyLoading) 1671 lazyLoading = false; 1672 return success(); 1673 } 1674 1675 //===----------------------------------------------------------------------===// 1676 // Dialect Section 1677 1678 LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { 1679 if (dialect) 1680 return success(); 1681 Dialect *loadedDialect = ctx->getOrLoadDialect(name); 1682 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { 1683 return reader.emitError("dialect '") 1684 << name 1685 << "' is unknown. 
If this is intended, please call " 1686 "allowUnregisteredDialects() on the MLIRContext, or use " 1687 "-allow-unregistered-dialect with the MLIR tool used."; 1688 } 1689 dialect = loadedDialect; 1690 1691 // If the dialect was actually loaded, check to see if it has a bytecode 1692 // interface. 1693 if (loadedDialect) 1694 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); 1695 if (!versionBuffer.empty()) { 1696 if (!interface) 1697 return reader.emitError("dialect '") 1698 << name 1699 << "' does not implement the bytecode interface, " 1700 "but found a version entry"; 1701 EncodingReader encReader(versionBuffer, reader.getLoc()); 1702 DialectReader versionReader = reader.withEncodingReader(encReader); 1703 loadedVersion = interface->readVersion(versionReader); 1704 if (!loadedVersion) 1705 return failure(); 1706 } 1707 return success(); 1708 } 1709 1710 LogicalResult 1711 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { 1712 EncodingReader sectionReader(sectionData, fileLoc); 1713 1714 // Parse the number of dialects in the section. 1715 uint64_t numDialects; 1716 if (failed(sectionReader.parseVarInt(numDialects))) 1717 return failure(); 1718 dialects.resize(numDialects); 1719 1720 // Parse each of the dialects. 1721 for (uint64_t i = 0; i < numDialects; ++i) { 1722 /// Before version kDialectVersioning, there wasn't any versioning available 1723 /// for dialects, and the entryIdx represent the string itself. 1724 if (version < bytecode::kDialectVersioning) { 1725 if (failed(stringReader.parseString(sectionReader, dialects[i].name))) 1726 return failure(); 1727 continue; 1728 } 1729 // Parse ID representing dialect and version. 
1730 uint64_t dialectNameIdx; 1731 bool versionAvailable; 1732 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, 1733 versionAvailable))) 1734 return failure(); 1735 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, 1736 dialects[i].name))) 1737 return failure(); 1738 if (versionAvailable) { 1739 bytecode::Section::ID sectionID; 1740 if (failed( 1741 sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) 1742 return failure(); 1743 if (sectionID != bytecode::Section::kDialectVersions) { 1744 emitError(fileLoc, "expected dialect version section"); 1745 return failure(); 1746 } 1747 } 1748 } 1749 1750 // Parse the operation names, which are grouped by dialect. 1751 auto parseOpName = [&](BytecodeDialect *dialect) { 1752 StringRef opName; 1753 std::optional<bool> wasRegistered; 1754 // Prior to version kNativePropertiesEncoding, the information about wheter 1755 // an op was registered or not wasn't encoded. 1756 if (version < bytecode::kNativePropertiesEncoding) { 1757 if (failed(stringReader.parseString(sectionReader, opName))) 1758 return failure(); 1759 } else { 1760 bool wasRegisteredFlag; 1761 if (failed(stringReader.parseStringWithFlag(sectionReader, opName, 1762 wasRegisteredFlag))) 1763 return failure(); 1764 wasRegistered = wasRegisteredFlag; 1765 } 1766 opNames.emplace_back(dialect, opName, wasRegistered); 1767 return success(); 1768 }; 1769 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation 1770 // where the number of ops are known. 
1771 if (version >= bytecode::kElideUnknownBlockArgLocation) { 1772 uint64_t numOps; 1773 if (failed(sectionReader.parseVarInt(numOps))) 1774 return failure(); 1775 opNames.reserve(numOps); 1776 } 1777 while (!sectionReader.empty()) 1778 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) 1779 return failure(); 1780 return success(); 1781 } 1782 1783 FailureOr<OperationName> 1784 BytecodeReader::Impl::parseOpName(EncodingReader &reader, 1785 std::optional<bool> &wasRegistered) { 1786 BytecodeOperationName *opName = nullptr; 1787 if (failed(parseEntry(reader, opNames, opName, "operation name"))) 1788 return failure(); 1789 wasRegistered = opName->wasRegistered; 1790 // Check to see if this operation name has already been resolved. If we 1791 // haven't, load the dialect and build the operation name. 1792 if (!opName->opName) { 1793 // Load the dialect and its version. 1794 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1795 reader, version); 1796 if (failed(opName->dialect->load(dialectReader, getContext()))) 1797 return failure(); 1798 // If the opName is empty, this is because we use to accept names such as 1799 // `foo` without any `.` separator. We shouldn't tolerate this in textual 1800 // format anymore but for now we'll be backward compatible. This can only 1801 // happen with unregistered dialects. 1802 if (opName->name.empty()) { 1803 if (opName->dialect->getLoadedDialect()) 1804 return emitError(fileLoc) << "has an empty opname for dialect '" 1805 << opName->dialect->name << "'\n"; 1806 1807 opName->opName.emplace(opName->dialect->name, getContext()); 1808 } else { 1809 opName->opName.emplace((opName->dialect->name + "." 
+ opName->name).str(), 1810 getContext()); 1811 } 1812 } 1813 return *opName->opName; 1814 } 1815 1816 //===----------------------------------------------------------------------===// 1817 // Resource Section 1818 1819 LogicalResult BytecodeReader::Impl::parseResourceSection( 1820 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, 1821 std::optional<ArrayRef<uint8_t>> resourceOffsetData) { 1822 // Ensure both sections are either present or not. 1823 if (resourceData.has_value() != resourceOffsetData.has_value()) { 1824 if (resourceOffsetData) 1825 return emitError(fileLoc, "unexpected resource offset section when " 1826 "resource section is not present"); 1827 return emitError( 1828 fileLoc, 1829 "expected resource offset section when resource section is present"); 1830 } 1831 1832 // If the resource sections are absent, there is nothing to do. 1833 if (!resourceData) 1834 return success(); 1835 1836 // Initialize the resource reader with the resource sections. 1837 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1838 reader, version); 1839 return resourceReader.initialize(fileLoc, config, dialects, stringReader, 1840 *resourceData, *resourceOffsetData, 1841 dialectReader, bufferOwnerRef); 1842 } 1843 1844 //===----------------------------------------------------------------------===// 1845 // UseListOrder Helpers 1846 1847 FailureOr<BytecodeReader::Impl::UseListMapT> 1848 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, 1849 uint64_t numResults) { 1850 BytecodeReader::Impl::UseListMapT map; 1851 uint64_t numValuesToRead = 1; 1852 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) 1853 return failure(); 1854 1855 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { 1856 uint64_t resultIdx = 0; 1857 if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) 1858 return failure(); 1859 1860 uint64_t numValues; 1861 bool indexPairEncoding; 1862 if 
(failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) 1863 return failure(); 1864 1865 SmallVector<unsigned, 4> useListOrders; 1866 for (size_t idx = 0; idx < numValues; idx++) { 1867 uint64_t index; 1868 if (failed(reader.parseVarInt(index))) 1869 return failure(); 1870 useListOrders.push_back(index); 1871 } 1872 1873 // Store in a map the result index 1874 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, 1875 std::move(useListOrders))); 1876 } 1877 1878 return map; 1879 } 1880 1881 /// Sorts each use according to the order specified in the use-list parsed. If 1882 /// the custom use-list is not found, this means that the order needs to be 1883 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie 1884 /// on the same operation, the order will follow the reverse operand number 1885 /// ordering. 1886 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { 1887 // Early return for trivial use-lists. 1888 if (value.use_empty() || value.hasOneUse()) 1889 return success(); 1890 1891 bool hasIncomingOrder = 1892 valueToUseListMap.contains(value.getAsOpaquePointer()); 1893 1894 // Compute the current order of the use-list with respect to the global 1895 // ordering. Detect if the order is already sorted while doing so. 1896 bool alreadySorted = true; 1897 auto &firstUse = *value.use_begin(); 1898 uint64_t prevID = 1899 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); 1900 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; 1901 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1902 uint64_t currentID = bytecode::getUseID( 1903 item.value(), operationIDs.at(item.value().getOwner())); 1904 alreadySorted &= prevID > currentID; 1905 currentOrder.push_back({item.index(), currentID}); 1906 prevID = currentID; 1907 } 1908 1909 // If the order is already sorted, and there wasn't a custom order to apply 1910 // from the bytecode file, we are done. 
1911 if (alreadySorted && !hasIncomingOrder) 1912 return success(); 1913 1914 // If not already sorted, sort the indices of the current order by descending 1915 // useIDs. 1916 if (!alreadySorted) 1917 std::sort( 1918 currentOrder.begin(), currentOrder.end(), 1919 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1920 1921 if (!hasIncomingOrder) { 1922 // If the bytecode file did not contain any custom use-list order, it means 1923 // that the order was descending useID. Hence, shuffle by the first index 1924 // of the `currentOrder` pair. 1925 SmallVector<unsigned> shuffle = SmallVector<unsigned>( 1926 llvm::map_range(currentOrder, [&](auto item) { return item.first; })); 1927 value.shuffleUseList(shuffle); 1928 return success(); 1929 } 1930 1931 // Pull the custom order info from the map. 1932 UseListOrderStorage customOrder = 1933 valueToUseListMap.at(value.getAsOpaquePointer()); 1934 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); 1935 uint64_t numUses = 1936 std::distance(value.getUses().begin(), value.getUses().end()); 1937 1938 // If the encoding was a pair of indices `(src, dst)` for every permutation, 1939 // reconstruct the shuffle vector for every use. Initialize the shuffle vector 1940 // as identity, and then apply the mapping encoded in the indices. 1941 if (customOrder.isIndexPairEncoding) { 1942 // Return failure if the number of indices was not representing pairs. 1943 if (shuffle.size() & 1) 1944 return failure(); 1945 1946 SmallVector<unsigned, 4> newShuffle(numUses); 1947 size_t idx = 0; 1948 std::iota(newShuffle.begin(), newShuffle.end(), idx); 1949 for (idx = 0; idx < shuffle.size(); idx += 2) 1950 newShuffle[shuffle[idx]] = shuffle[idx + 1]; 1951 1952 shuffle = std::move(newShuffle); 1953 } 1954 1955 // Make sure that the indices represent a valid mapping. That is, the sum of 1956 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no 1957 // duplicates are allowed in the list. 
1958 DenseSet<unsigned> set; 1959 uint64_t accumulator = 0; 1960 for (const auto &elem : shuffle) { 1961 if (set.contains(elem)) 1962 return failure(); 1963 accumulator += elem; 1964 set.insert(elem); 1965 } 1966 if (numUses != shuffle.size() || 1967 accumulator != (((numUses - 1) * numUses) >> 1)) 1968 return failure(); 1969 1970 // Apply the current ordering map onto the shuffle vector to get the final 1971 // use-list sorting indices before shuffling. 1972 shuffle = SmallVector<unsigned, 4>(llvm::map_range( 1973 currentOrder, [&](auto item) { return shuffle[item.first]; })); 1974 value.shuffleUseList(shuffle); 1975 return success(); 1976 } 1977 1978 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { 1979 // Precompute operation IDs according to the pre-order walk of the IR. We 1980 // can't do this while parsing since parseRegions ordering is not strictly 1981 // equal to the pre-order walk. 1982 unsigned operationID = 0; 1983 topLevelOp->walk<mlir::WalkOrder::PreOrder>( 1984 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); 1985 1986 auto blockWalk = topLevelOp->walk([this](Block *block) { 1987 for (auto arg : block->getArguments()) 1988 if (failed(sortUseListOrder(arg))) 1989 return WalkResult::interrupt(); 1990 return WalkResult::advance(); 1991 }); 1992 1993 auto resultWalk = topLevelOp->walk([this](Operation *op) { 1994 for (auto result : op->getResults()) 1995 if (failed(sortUseListOrder(result))) 1996 return WalkResult::interrupt(); 1997 return WalkResult::advance(); 1998 }); 1999 2000 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); 2001 } 2002 2003 //===----------------------------------------------------------------------===// 2004 // IR Section 2005 2006 LogicalResult 2007 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, 2008 Block *block) { 2009 EncodingReader reader(sectionData, fileLoc); 2010 2011 // A stack of operation regions currently being read from the 
bytecode. 2012 std::vector<RegionReadState> regionStack; 2013 2014 // Parse the top-level block using a temporary module operation. 2015 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); 2016 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); 2017 regionStack.back().curBlocks.push_back(moduleOp->getBody()); 2018 regionStack.back().curBlock = regionStack.back().curRegion->begin(); 2019 if (failed(parseBlockHeader(reader, regionStack.back()))) 2020 return failure(); 2021 valueScopes.emplace_back(); 2022 valueScopes.back().push(regionStack.back()); 2023 2024 // Iteratively parse regions until everything has been resolved. 2025 while (!regionStack.empty()) 2026 if (failed(parseRegions(regionStack, regionStack.back()))) 2027 return failure(); 2028 if (!forwardRefOps.empty()) { 2029 return reader.emitError( 2030 "not all forward unresolved forward operand references"); 2031 } 2032 2033 // Sort use-lists according to what specified in bytecode. 2034 if (failed(processUseLists(*moduleOp))) 2035 return reader.emitError( 2036 "parsed use-list orders were invalid and could not be applied"); 2037 2038 // Resolve dialect version. 2039 for (const BytecodeDialect &byteCodeDialect : dialects) { 2040 // Parsing is complete, give an opportunity to each dialect to visit the 2041 // IR and perform upgrades. 2042 if (!byteCodeDialect.loadedVersion) 2043 continue; 2044 if (byteCodeDialect.interface && 2045 failed(byteCodeDialect.interface->upgradeFromVersion( 2046 *moduleOp, *byteCodeDialect.loadedVersion))) 2047 return failure(); 2048 } 2049 2050 // Verify that the parsed operations are valid. 2051 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) 2052 return failure(); 2053 2054 // Splice the parsed operations over to the provided top-level block. 
2055 auto &parsedOps = moduleOp->getBody()->getOperations(); 2056 auto &destOps = block->getOperations(); 2057 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); 2058 return success(); 2059 } 2060 2061 LogicalResult 2062 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, 2063 RegionReadState &readState) { 2064 // Process regions, blocks, and operations until the end or if a nested 2065 // region is encountered. In this case we push a new state in regionStack and 2066 // return, the processing of the current region will resume afterward. 2067 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { 2068 // If the current block hasn't been setup yet, parse the header for this 2069 // region. The current block is already setup when this function was 2070 // interrupted to recurse down in a nested region and we resume the current 2071 // block after processing the nested region. 2072 if (readState.curBlock == Region::iterator()) { 2073 if (failed(parseRegion(readState))) 2074 return failure(); 2075 2076 // If the region is empty, there is nothing to more to do. 2077 if (readState.curRegion->empty()) 2078 continue; 2079 } 2080 2081 // Parse the blocks within the region. 2082 EncodingReader &reader = *readState.reader; 2083 do { 2084 while (readState.numOpsRemaining--) { 2085 // Read in the next operation. We don't read its regions directly, we 2086 // handle those afterwards as necessary. 2087 bool isIsolatedFromAbove = false; 2088 FailureOr<Operation *> op = 2089 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); 2090 if (failed(op)) 2091 return failure(); 2092 2093 // If the op has regions, add it to the stack for processing and return: 2094 // we stop the processing of the current region and resume it after the 2095 // inner one is completed. Unless LazyLoading is activated in which case 2096 // nested region parsing is delayed. 
2097 if ((*op)->getNumRegions()) { 2098 RegionReadState childState(*op, &reader, isIsolatedFromAbove); 2099 2100 // Isolated regions are encoded as a section in version 2 and above. 2101 if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { 2102 bytecode::Section::ID sectionID; 2103 ArrayRef<uint8_t> sectionData; 2104 if (failed(reader.parseSection(sectionID, sectionData))) 2105 return failure(); 2106 if (sectionID != bytecode::Section::kIR) 2107 return emitError(fileLoc, "expected IR section for region"); 2108 childState.owningReader = 2109 std::make_unique<EncodingReader>(sectionData, fileLoc); 2110 childState.reader = childState.owningReader.get(); 2111 2112 // If the user has a callback set, they have the opportunity to 2113 // control lazyloading as we go. 2114 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) { 2115 lazyLoadableOps.emplace_back(*op, std::move(childState)); 2116 lazyLoadableOpsMap.try_emplace(*op, 2117 std::prev(lazyLoadableOps.end())); 2118 continue; 2119 } 2120 } 2121 regionStack.push_back(std::move(childState)); 2122 2123 // If the op is isolated from above, push a new value scope. 2124 if (isIsolatedFromAbove) 2125 valueScopes.emplace_back(); 2126 return success(); 2127 } 2128 } 2129 2130 // Move to the next block of the region. 2131 if (++readState.curBlock == readState.curRegion->end()) 2132 break; 2133 if (failed(parseBlockHeader(reader, readState))) 2134 return failure(); 2135 } while (true); 2136 2137 // Reset the current block and any values reserved for this region. 2138 readState.curBlock = {}; 2139 valueScopes.back().pop(readState); 2140 } 2141 2142 // When the regions have been fully parsed, pop them off of the read stack. If 2143 // the regions were isolated from above, we also pop the last value scope. 
2144 if (readState.isIsolatedFromAbove) { 2145 assert(!valueScopes.empty() && "Expect a valueScope after reading region"); 2146 valueScopes.pop_back(); 2147 } 2148 assert(!regionStack.empty() && "Expect a regionStack after reading region"); 2149 regionStack.pop_back(); 2150 return success(); 2151 } 2152 2153 FailureOr<Operation *> 2154 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, 2155 RegionReadState &readState, 2156 bool &isIsolatedFromAbove) { 2157 // Parse the name of the operation. 2158 std::optional<bool> wasRegistered; 2159 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); 2160 if (failed(opName)) 2161 return failure(); 2162 2163 // Parse the operation mask, which indicates which components of the operation 2164 // are present. 2165 uint8_t opMask; 2166 if (failed(reader.parseByte(opMask))) 2167 return failure(); 2168 2169 /// Parse the location. 2170 LocationAttr opLoc; 2171 if (failed(parseAttribute(reader, opLoc))) 2172 return failure(); 2173 2174 // With the location and name resolved, we can start building the operation 2175 // state. 2176 OperationState opState(opLoc, *opName); 2177 2178 // Parse the attributes of the operation. 2179 if (opMask & bytecode::OpEncodingMask::kHasAttrs) { 2180 DictionaryAttr dictAttr; 2181 if (failed(parseAttribute(reader, dictAttr))) 2182 return failure(); 2183 opState.attributes = dictAttr; 2184 } 2185 2186 if (opMask & bytecode::OpEncodingMask::kHasProperties) { 2187 // kHasProperties wasn't emitted in older bytecode, we should never get 2188 // there without also having the `wasRegistered` flag available. 2189 if (!wasRegistered) 2190 return emitError(fileLoc, 2191 "Unexpected missing `wasRegistered` opname flag at " 2192 "bytecode version ") 2193 << version << " with properties."; 2194 // When an operation is emitted without being registered, the properties are 2195 // stored as an attribute. 
Otherwise the op must implement the bytecode 2196 // interface and control the serialization. 2197 if (wasRegistered) { 2198 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 2199 reader, version); 2200 if (failed( 2201 propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) 2202 return failure(); 2203 } else { 2204 // If the operation wasn't registered when it was emitted, the properties 2205 // was serialized as an attribute. 2206 if (failed(parseAttribute(reader, opState.propertiesAttr))) 2207 return failure(); 2208 } 2209 } 2210 2211 /// Parse the results of the operation. 2212 if (opMask & bytecode::OpEncodingMask::kHasResults) { 2213 uint64_t numResults; 2214 if (failed(reader.parseVarInt(numResults))) 2215 return failure(); 2216 opState.types.resize(numResults); 2217 for (int i = 0, e = numResults; i < e; ++i) 2218 if (failed(parseType(reader, opState.types[i]))) 2219 return failure(); 2220 } 2221 2222 /// Parse the operands of the operation. 2223 if (opMask & bytecode::OpEncodingMask::kHasOperands) { 2224 uint64_t numOperands; 2225 if (failed(reader.parseVarInt(numOperands))) 2226 return failure(); 2227 opState.operands.resize(numOperands); 2228 for (int i = 0, e = numOperands; i < e; ++i) 2229 if (!(opState.operands[i] = parseOperand(reader))) 2230 return failure(); 2231 } 2232 2233 /// Parse the successors of the operation. 2234 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { 2235 uint64_t numSuccs; 2236 if (failed(reader.parseVarInt(numSuccs))) 2237 return failure(); 2238 opState.successors.resize(numSuccs); 2239 for (int i = 0, e = numSuccs; i < e; ++i) { 2240 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], 2241 "successor"))) 2242 return failure(); 2243 } 2244 } 2245 2246 /// Parse the use-list orders for the results of the operation. Use-list 2247 /// orders are available since version 3 of the bytecode. 
2248 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; 2249 if (version >= bytecode::kUseListOrdering && 2250 (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { 2251 size_t numResults = opState.types.size(); 2252 auto parseResult = parseUseListOrderForRange(reader, numResults); 2253 if (failed(parseResult)) 2254 return failure(); 2255 resultIdxToUseListMap = std::move(*parseResult); 2256 } 2257 2258 /// Parse the regions of the operation. 2259 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { 2260 uint64_t numRegions; 2261 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) 2262 return failure(); 2263 2264 opState.regions.reserve(numRegions); 2265 for (int i = 0, e = numRegions; i < e; ++i) 2266 opState.regions.push_back(std::make_unique<Region>()); 2267 } 2268 2269 // Create the operation at the back of the current block. 2270 Operation *op = Operation::create(opState); 2271 readState.curBlock->push_back(op); 2272 2273 // If the operation had results, update the value references. 2274 if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) 2275 return failure(); 2276 2277 /// Store a map for every value that received a custom use-list order from the 2278 /// bytecode file. 2279 if (resultIdxToUseListMap.has_value()) { 2280 for (size_t idx = 0; idx < op->getNumResults(); idx++) { 2281 if (resultIdxToUseListMap->contains(idx)) { 2282 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), 2283 resultIdxToUseListMap->at(idx)); 2284 } 2285 } 2286 } 2287 return op; 2288 } 2289 2290 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { 2291 EncodingReader &reader = *readState.reader; 2292 2293 // Parse the number of blocks in the region. 2294 uint64_t numBlocks; 2295 if (failed(reader.parseVarInt(numBlocks))) 2296 return failure(); 2297 2298 // If the region is empty, there is nothing else to do. 
2299 if (numBlocks == 0) 2300 return success(); 2301 2302 // Parse the number of values defined in this region. 2303 uint64_t numValues; 2304 if (failed(reader.parseVarInt(numValues))) 2305 return failure(); 2306 readState.numValues = numValues; 2307 2308 // Create the blocks within this region. We do this before processing so that 2309 // we can rely on the blocks existing when creating operations. 2310 readState.curBlocks.clear(); 2311 readState.curBlocks.reserve(numBlocks); 2312 for (uint64_t i = 0; i < numBlocks; ++i) { 2313 readState.curBlocks.push_back(new Block()); 2314 readState.curRegion->push_back(readState.curBlocks.back()); 2315 } 2316 2317 // Prepare the current value scope for this region. 2318 valueScopes.back().push(readState); 2319 2320 // Parse the entry block of the region. 2321 readState.curBlock = readState.curRegion->begin(); 2322 return parseBlockHeader(reader, readState); 2323 } 2324 2325 LogicalResult 2326 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, 2327 RegionReadState &readState) { 2328 bool hasArgs; 2329 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) 2330 return failure(); 2331 2332 // Parse the arguments of the block. 2333 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) 2334 return failure(); 2335 2336 // Uselist orders are available since version 3 of the bytecode. 
2337 if (version < bytecode::kUseListOrdering) 2338 return success(); 2339 2340 uint8_t hasUseListOrders = 0; 2341 if (hasArgs && failed(reader.parseByte(hasUseListOrders))) 2342 return failure(); 2343 2344 if (!hasUseListOrders) 2345 return success(); 2346 2347 Block &blk = *readState.curBlock; 2348 auto argIdxToUseListMap = 2349 parseUseListOrderForRange(reader, blk.getNumArguments()); 2350 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) 2351 return failure(); 2352 2353 for (size_t idx = 0; idx < blk.getNumArguments(); idx++) 2354 if (argIdxToUseListMap->contains(idx)) 2355 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), 2356 argIdxToUseListMap->at(idx)); 2357 2358 // We don't parse the operations of the block here, that's done elsewhere. 2359 return success(); 2360 } 2361 2362 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, 2363 Block *block) { 2364 // Parse the value ID for the first argument, and the number of arguments. 2365 uint64_t numArgs; 2366 if (failed(reader.parseVarInt(numArgs))) 2367 return failure(); 2368 2369 SmallVector<Type> argTypes; 2370 SmallVector<Location> argLocs; 2371 argTypes.reserve(numArgs); 2372 argLocs.reserve(numArgs); 2373 2374 Location unknownLoc = UnknownLoc::get(config.getContext()); 2375 while (numArgs--) { 2376 Type argType; 2377 LocationAttr argLoc = unknownLoc; 2378 if (version >= bytecode::kElideUnknownBlockArgLocation) { 2379 // Parse the type with hasLoc flag to determine if it has type. 2380 uint64_t typeIdx; 2381 bool hasLoc; 2382 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || 2383 !(argType = attrTypeReader.resolveType(typeIdx))) 2384 return failure(); 2385 if (hasLoc && failed(parseAttribute(reader, argLoc))) 2386 return failure(); 2387 } else { 2388 // All args has type and location. 
2389 if (failed(parseType(reader, argType)) || 2390 failed(parseAttribute(reader, argLoc))) 2391 return failure(); 2392 } 2393 argTypes.push_back(argType); 2394 argLocs.push_back(argLoc); 2395 } 2396 block->addArguments(argTypes, argLocs); 2397 return defineValues(reader, block->getArguments()); 2398 } 2399 2400 //===----------------------------------------------------------------------===// 2401 // Value Processing 2402 2403 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { 2404 std::vector<Value> &values = valueScopes.back().values; 2405 Value *value = nullptr; 2406 if (failed(parseEntry(reader, values, value, "value"))) 2407 return Value(); 2408 2409 // Create a new forward reference if necessary. 2410 if (!*value) 2411 *value = createForwardRef(); 2412 return *value; 2413 } 2414 2415 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, 2416 ValueRange newValues) { 2417 ValueScope &valueScope = valueScopes.back(); 2418 std::vector<Value> &values = valueScope.values; 2419 2420 unsigned &valueID = valueScope.nextValueIDs.back(); 2421 unsigned valueIDEnd = valueID + newValues.size(); 2422 if (valueIDEnd > values.size()) { 2423 return reader.emitError( 2424 "value index range was outside of the expected range for " 2425 "the parent region, got [", 2426 valueID, ", ", valueIDEnd, "), but the maximum index was ", 2427 values.size() - 1); 2428 } 2429 2430 // Assign the values and update any forward references. 2431 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { 2432 Value newValue = newValues[i]; 2433 2434 // Check to see if a definition for this value already exists. 2435 if (Value oldValue = std::exchange(values[valueID], newValue)) { 2436 Operation *forwardRefOp = oldValue.getDefiningOp(); 2437 2438 // Assert that this is a forward reference operation. Given how we compute 2439 // definition ids (incrementally as we parse), it shouldn't be possible 2440 // for the value to be defined any other way. 
2441 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && 2442 "value index was already defined?"); 2443 2444 oldValue.replaceAllUsesWith(newValue); 2445 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); 2446 } 2447 } 2448 return success(); 2449 } 2450 2451 Value BytecodeReader::Impl::createForwardRef() { 2452 // Check for an avaliable existing operation to use. Otherwise, create a new 2453 // fake operation to use for the reference. 2454 if (!openForwardRefOps.empty()) { 2455 Operation *op = &openForwardRefOps.back(); 2456 op->moveBefore(&forwardRefOps, forwardRefOps.end()); 2457 } else { 2458 forwardRefOps.push_back(Operation::create(forwardRefOpState)); 2459 } 2460 return forwardRefOps.back().getResult(0); 2461 } 2462 2463 //===----------------------------------------------------------------------===// 2464 // Entry Points 2465 //===----------------------------------------------------------------------===// 2466 2467 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } 2468 2469 BytecodeReader::BytecodeReader( 2470 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, 2471 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2472 Location sourceFileLoc = 2473 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2474 /*line=*/0, /*column=*/0); 2475 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, 2476 bufferOwnerRef); 2477 } 2478 2479 LogicalResult BytecodeReader::readTopLevel( 2480 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2481 return impl->read(block, lazyOpsCallback); 2482 } 2483 2484 int64_t BytecodeReader::getNumOpsToMaterialize() const { 2485 return impl->getNumOpsToMaterialize(); 2486 } 2487 2488 bool BytecodeReader::isMaterializable(Operation *op) { 2489 return impl->isMaterializable(op); 2490 } 2491 2492 LogicalResult BytecodeReader::materialize( 2493 Operation *op, 
llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2494 return impl->materialize(op, lazyOpsCallback); 2495 } 2496 2497 LogicalResult 2498 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { 2499 return impl->finalize(shouldMaterialize); 2500 } 2501 2502 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { 2503 return buffer.getBuffer().startswith("ML\xefR"); 2504 } 2505 2506 /// Read the bytecode from the provided memory buffer reference. 2507 /// `bufferOwnerRef` if provided is the owning source manager for the buffer, 2508 /// and may be used to extend the lifetime of the buffer. 2509 static LogicalResult 2510 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, 2511 const ParserConfig &config, 2512 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2513 Location sourceFileLoc = 2514 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2515 /*line=*/0, /*column=*/0); 2516 if (!isBytecode(buffer)) { 2517 return emitError(sourceFileLoc, 2518 "input buffer is not an MLIR bytecode file"); 2519 } 2520 2521 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, 2522 buffer, bufferOwnerRef); 2523 return reader.read(block, /*lazyOpsCallback=*/nullptr); 2524 } 2525 2526 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, 2527 const ParserConfig &config) { 2528 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); 2529 } 2530 LogicalResult 2531 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 2532 Block *block, const ParserConfig &config) { 2533 return readBytecodeFileImpl( 2534 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, 2535 sourceMgr); 2536 } 2537