//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <list>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <type_traits>

#define DEBUG_TYPE "mlir-bytecode-reader"

using namespace mlir;

/// Stringify the given section ID.
39 static std::string toString(bytecode::Section::ID sectionID) { 40 switch (sectionID) { 41 case bytecode::Section::kString: 42 return "String (0)"; 43 case bytecode::Section::kDialect: 44 return "Dialect (1)"; 45 case bytecode::Section::kAttrType: 46 return "AttrType (2)"; 47 case bytecode::Section::kAttrTypeOffset: 48 return "AttrTypeOffset (3)"; 49 case bytecode::Section::kIR: 50 return "IR (4)"; 51 case bytecode::Section::kResource: 52 return "Resource (5)"; 53 case bytecode::Section::kResourceOffset: 54 return "ResourceOffset (6)"; 55 case bytecode::Section::kDialectVersions: 56 return "DialectVersions (7)"; 57 case bytecode::Section::kProperties: 58 return "Properties (8)"; 59 default: 60 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); 61 } 62 } 63 64 /// Returns true if the given top-level section ID is optional. 65 static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { 66 switch (sectionID) { 67 case bytecode::Section::kString: 68 case bytecode::Section::kDialect: 69 case bytecode::Section::kAttrType: 70 case bytecode::Section::kAttrTypeOffset: 71 case bytecode::Section::kIR: 72 return false; 73 case bytecode::Section::kResource: 74 case bytecode::Section::kResourceOffset: 75 case bytecode::Section::kDialectVersions: 76 return true; 77 case bytecode::Section::kProperties: 78 return version < bytecode::kNativePropertiesEncoding; 79 default: 80 llvm_unreachable("unknown section ID"); 81 } 82 } 83 84 //===----------------------------------------------------------------------===// 85 // EncodingReader 86 //===----------------------------------------------------------------------===// 87 88 namespace { 89 class EncodingReader { 90 public: 91 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) 92 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {} 93 explicit EncodingReader(StringRef contents, Location fileLoc) 94 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), 
95 contents.size()}, 96 fileLoc) {} 97 98 /// Returns true if the entire section has been read. 99 bool empty() const { return dataIt == buffer.end(); } 100 101 /// Returns the remaining size of the bytecode. 102 size_t size() const { return buffer.end() - dataIt; } 103 104 /// Align the current reader position to the specified alignment. 105 LogicalResult alignTo(unsigned alignment) { 106 if (!llvm::isPowerOf2_32(alignment)) 107 return emitError("expected alignment to be a power-of-two"); 108 109 auto isUnaligned = [&](const uint8_t *ptr) { 110 return ((uintptr_t)ptr & (alignment - 1)) != 0; 111 }; 112 113 // Shift the reader position to the next alignment boundary. 114 while (isUnaligned(dataIt)) { 115 uint8_t padding; 116 if (failed(parseByte(padding))) 117 return failure(); 118 if (padding != bytecode::kAlignmentByte) { 119 return emitError("expected alignment byte (0xCB), but got: '0x" + 120 llvm::utohexstr(padding) + "'"); 121 } 122 } 123 124 // Ensure the data iterator is now aligned. This case is unlikely because we 125 // *just* went through the effort to align the data iterator. 126 if (LLVM_UNLIKELY(isUnaligned(dataIt))) { 127 return emitError("expected data iterator aligned to ", alignment, 128 ", but got pointer: '0x" + 129 llvm::utohexstr((uintptr_t)dataIt) + "'"); 130 } 131 132 return success(); 133 } 134 135 /// Emit an error using the given arguments. 136 template <typename... Args> 137 InFlightDiagnostic emitError(Args &&...args) const { 138 return ::emitError(fileLoc).append(std::forward<Args>(args)...); 139 } 140 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } 141 142 /// Parse a single byte from the stream. 143 template <typename T> 144 LogicalResult parseByte(T &value) { 145 if (empty()) 146 return emitError("attempting to parse a byte at the end of the bytecode"); 147 value = static_cast<T>(*dataIt++); 148 return success(); 149 } 150 /// Parse a range of bytes of 'length' into the given result. 
151 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { 152 if (length > size()) { 153 return emitError("attempting to parse ", length, " bytes when only ", 154 size(), " remain"); 155 } 156 result = {dataIt, length}; 157 dataIt += length; 158 return success(); 159 } 160 /// Parse a range of bytes of 'length' into the given result, which can be 161 /// assumed to be large enough to hold `length`. 162 LogicalResult parseBytes(size_t length, uint8_t *result) { 163 if (length > size()) { 164 return emitError("attempting to parse ", length, " bytes when only ", 165 size(), " remain"); 166 } 167 memcpy(result, dataIt, length); 168 dataIt += length; 169 return success(); 170 } 171 172 /// Parse an aligned blob of data, where the alignment was encoded alongside 173 /// the data. 174 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, 175 uint64_t &alignment) { 176 uint64_t dataSize; 177 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || 178 failed(alignTo(alignment))) 179 return failure(); 180 return parseBytes(dataSize, data); 181 } 182 183 /// Parse a variable length encoded integer from the byte stream. The first 184 /// encoded byte contains a prefix in the low bits indicating the encoded 185 /// length of the value. This length prefix is a bit sequence of '0's followed 186 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes 187 /// (not including the prefix byte). All remaining bits in the first byte, 188 /// along with all of the bits in additional bytes, provide the value of the 189 /// integer encoded in little-endian order. 190 LogicalResult parseVarInt(uint64_t &result) { 191 // Parse the first byte of the encoding, which contains the length prefix. 192 if (failed(parseByte(result))) 193 return failure(); 194 195 // Handle the overwhelmingly common case where the value is stored in a 196 // single byte. In this case, the first bit is the `1` marker bit. 
197 if (LLVM_LIKELY(result & 1)) { 198 result >>= 1; 199 return success(); 200 } 201 202 // Handle the overwhelming uncommon case where the value required all 8 203 // bytes (i.e. a really really big number). In this case, the marker byte is 204 // all zeros: `00000000`. 205 if (LLVM_UNLIKELY(result == 0)) { 206 llvm::support::ulittle64_t resultLE; 207 if (failed(parseBytes(sizeof(resultLE), 208 reinterpret_cast<uint8_t *>(&resultLE)))) 209 return failure(); 210 result = resultLE; 211 return success(); 212 } 213 return parseMultiByteVarInt(result); 214 } 215 216 /// Parse a signed variable length encoded integer from the byte stream. A 217 /// signed varint is encoded as a normal varint with zigzag encoding applied, 218 /// i.e. the low bit of the value is used to indicate the sign. 219 LogicalResult parseSignedVarInt(uint64_t &result) { 220 if (failed(parseVarInt(result))) 221 return failure(); 222 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) 223 result = (result >> 1) ^ (~(result & 1) + 1); 224 return success(); 225 } 226 227 /// Parse a variable length encoded integer whose low bit is used to encode an 228 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. 229 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { 230 if (failed(parseVarInt(result))) 231 return failure(); 232 flag = result & 1; 233 result >>= 1; 234 return success(); 235 } 236 237 /// Skip the first `length` bytes within the reader. 238 LogicalResult skipBytes(size_t length) { 239 if (length > size()) { 240 return emitError("attempting to skip ", length, " bytes when only ", 241 size(), " remain"); 242 } 243 dataIt += length; 244 return success(); 245 } 246 247 /// Parse a null-terminated string into `result` (without including the NUL 248 /// terminator). 
249 LogicalResult parseNullTerminatedString(StringRef &result) { 250 const char *startIt = (const char *)dataIt; 251 const char *nulIt = (const char *)memchr(startIt, 0, size()); 252 if (!nulIt) 253 return emitError( 254 "malformed null-terminated string, no null character found"); 255 256 result = StringRef(startIt, nulIt - startIt); 257 dataIt = (const uint8_t *)nulIt + 1; 258 return success(); 259 } 260 261 /// Parse a section header, placing the kind of section in `sectionID` and the 262 /// contents of the section in `sectionData`. 263 LogicalResult parseSection(bytecode::Section::ID §ionID, 264 ArrayRef<uint8_t> §ionData) { 265 uint8_t sectionIDAndHasAlignment; 266 uint64_t length; 267 if (failed(parseByte(sectionIDAndHasAlignment)) || 268 failed(parseVarInt(length))) 269 return failure(); 270 271 // Extract the section ID and whether the section is aligned. The high bit 272 // of the ID is the alignment flag. 273 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & 274 0b01111111); 275 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; 276 277 // Check that the section is actually valid before trying to process its 278 // data. 279 if (sectionID >= bytecode::Section::kNumSections) 280 return emitError("invalid section ID: ", unsigned(sectionID)); 281 282 // Process the section alignment if present. 283 if (hasAlignment) { 284 uint64_t alignment; 285 if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) 286 return failure(); 287 } 288 289 // Parse the actual section data. 290 return parseBytes(static_cast<size_t>(length), sectionData); 291 } 292 293 Location getLoc() const { return fileLoc; } 294 295 private: 296 /// Parse a variable length encoded integer from the byte stream. This method 297 /// is a fallback when the number of bytes used to encode the value is greater 298 /// than 1, but less than the max (9). The provided `result` value can be 299 /// assumed to already contain the first byte of the value. 
300 /// NOTE: This method is marked noinline to avoid pessimizing the common case 301 /// of single byte encoding. 302 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) { 303 // Count the number of trailing zeros in the marker byte, this indicates the 304 // number of trailing bytes that are part of the value. We use `uint32_t` 305 // here because we only care about the first byte, and so that be actually 306 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop 307 // implementation). 308 uint32_t numBytes = llvm::countr_zero<uint32_t>(result); 309 assert(numBytes > 0 && numBytes <= 7 && 310 "unexpected number of trailing zeros in varint encoding"); 311 312 // Parse in the remaining bytes of the value. 313 llvm::support::ulittle64_t resultLE(result); 314 if (failed( 315 parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1))) 316 return failure(); 317 318 // Shift out the low-order bits that were used to mark how the value was 319 // encoded. 320 result = resultLE >> (numBytes + 1); 321 return success(); 322 } 323 324 /// The bytecode buffer. 325 ArrayRef<uint8_t> buffer; 326 327 /// The current iterator within the 'buffer'. 328 const uint8_t *dataIt; 329 330 /// A location for the bytecode used to report errors. 331 Location fileLoc; 332 }; 333 } // namespace 334 335 /// Resolve an index into the given entry list. `entry` may either be a 336 /// reference, in which case it is assigned to the corresponding value in 337 /// `entries`, or a pointer, in which case it is assigned to the address of the 338 /// element in `entries`. 339 template <typename RangeT, typename T> 340 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, 341 uint64_t index, T &entry, 342 StringRef entryStr) { 343 if (index >= entries.size()) 344 return reader.emitError("invalid ", entryStr, " index: ", index); 345 346 // If the provided entry is a pointer, resolve to the address of the entry. 
347 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) 348 entry = entries[index]; 349 else 350 entry = &entries[index]; 351 return success(); 352 } 353 354 /// Parse and resolve an index into the given entry list. 355 template <typename RangeT, typename T> 356 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, 357 T &entry, StringRef entryStr) { 358 uint64_t entryIdx; 359 if (failed(reader.parseVarInt(entryIdx))) 360 return failure(); 361 return resolveEntry(reader, entries, entryIdx, entry, entryStr); 362 } 363 364 //===----------------------------------------------------------------------===// 365 // StringSectionReader 366 //===----------------------------------------------------------------------===// 367 368 namespace { 369 /// This class is used to read references to the string section from the 370 /// bytecode. 371 class StringSectionReader { 372 public: 373 /// Initialize the string section reader with the given section data. 374 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); 375 376 /// Parse a shared string from the string section. The shared string is 377 /// encoded using an index to a corresponding string in the string section. 378 LogicalResult parseString(EncodingReader &reader, StringRef &result) const { 379 return parseEntry(reader, strings, result, "string"); 380 } 381 382 /// Parse a shared string from the string section. The shared string is 383 /// encoded using an index to a corresponding string in the string section. 384 /// This variant parses a flag compressed with the index. 385 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, 386 bool &flag) const { 387 uint64_t entryIdx; 388 if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) 389 return failure(); 390 return parseStringAtIndex(reader, entryIdx, result); 391 } 392 393 /// Parse a shared string from the string section. 
The shared string is 394 /// encoded using an index to a corresponding string in the string section. 395 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, 396 StringRef &result) const { 397 return resolveEntry(reader, strings, index, result, "string"); 398 } 399 400 private: 401 /// The table of strings referenced within the bytecode file. 402 SmallVector<StringRef> strings; 403 }; 404 } // namespace 405 406 LogicalResult StringSectionReader::initialize(Location fileLoc, 407 ArrayRef<uint8_t> sectionData) { 408 EncodingReader stringReader(sectionData, fileLoc); 409 410 // Parse the number of strings in the section. 411 uint64_t numStrings; 412 if (failed(stringReader.parseVarInt(numStrings))) 413 return failure(); 414 strings.resize(numStrings); 415 416 // Parse each of the strings. The sizes of the strings are encoded in reverse 417 // order, so that's the order we populate the table. 418 size_t stringDataEndOffset = sectionData.size(); 419 for (StringRef &string : llvm::reverse(strings)) { 420 uint64_t stringSize; 421 if (failed(stringReader.parseVarInt(stringSize))) 422 return failure(); 423 if (stringDataEndOffset < stringSize) { 424 return stringReader.emitError( 425 "string size exceeds the available data size"); 426 } 427 428 // Extract the string from the data, dropping the null character. 429 size_t stringOffset = stringDataEndOffset - stringSize; 430 string = StringRef( 431 reinterpret_cast<const char *>(sectionData.data() + stringOffset), 432 stringSize - 1); 433 stringDataEndOffset = stringOffset; 434 } 435 436 // Check that the only remaining data was for the strings, i.e. the reader 437 // should be at the same offset as the first string. 
438 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { 439 return stringReader.emitError("unexpected trailing data between the " 440 "offsets for strings and their data"); 441 } 442 return success(); 443 } 444 445 //===----------------------------------------------------------------------===// 446 // BytecodeDialect 447 //===----------------------------------------------------------------------===// 448 449 namespace { 450 class DialectReader; 451 452 /// This struct represents a dialect entry within the bytecode. 453 struct BytecodeDialect { 454 /// Load the dialect into the provided context if it hasn't been loaded yet. 455 /// Returns failure if the dialect couldn't be loaded *and* the provided 456 /// context does not allow unregistered dialects. The provided reader is used 457 /// for error emission if necessary. 458 LogicalResult load(const DialectReader &reader, MLIRContext *ctx); 459 460 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can 461 /// only be called after `load`. 462 Dialect *getLoadedDialect() const { 463 assert(dialect && 464 "expected `load` to be invoked before `getLoadedDialect`"); 465 return *dialect; 466 } 467 468 /// The loaded dialect entry. This field is std::nullopt if we haven't 469 /// attempted to load, nullptr if we failed to load, otherwise the loaded 470 /// dialect. 471 std::optional<Dialect *> dialect; 472 473 /// The bytecode interface of the dialect, or nullptr if the dialect does not 474 /// implement the bytecode interface. This field should only be checked if the 475 /// `dialect` field is not std::nullopt. 476 const BytecodeDialectInterface *interface = nullptr; 477 478 /// The name of the dialect. 479 StringRef name; 480 481 /// A buffer containing the encoding of the dialect version parsed. 482 ArrayRef<uint8_t> versionBuffer; 483 484 /// Lazy loaded dialect version from the handle above. 
485 std::unique_ptr<DialectVersion> loadedVersion; 486 }; 487 488 /// This struct represents an operation name entry within the bytecode. 489 struct BytecodeOperationName { 490 BytecodeOperationName(BytecodeDialect *dialect, StringRef name, 491 std::optional<bool> wasRegistered) 492 : dialect(dialect), name(name), wasRegistered(wasRegistered) {} 493 494 /// The loaded operation name, or std::nullopt if it hasn't been processed 495 /// yet. 496 std::optional<OperationName> opName; 497 498 /// The dialect that owns this operation name. 499 BytecodeDialect *dialect; 500 501 /// The name of the operation, without the dialect prefix. 502 StringRef name; 503 504 /// Whether this operation was registered when the bytecode was produced. 505 /// This flag is populated when bytecode version >=kNativePropertiesEncoding. 506 std::optional<bool> wasRegistered; 507 }; 508 } // namespace 509 510 /// Parse a single dialect group encoded in the byte stream. 511 static LogicalResult parseDialectGrouping( 512 EncodingReader &reader, 513 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 514 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { 515 // Parse the dialect and the number of entries in the group. 516 std::unique_ptr<BytecodeDialect> *dialect; 517 if (failed(parseEntry(reader, dialects, dialect, "dialect"))) 518 return failure(); 519 uint64_t numEntries; 520 if (failed(reader.parseVarInt(numEntries))) 521 return failure(); 522 523 for (uint64_t i = 0; i < numEntries; ++i) 524 if (failed(entryCallback(dialect->get()))) 525 return failure(); 526 return success(); 527 } 528 529 //===----------------------------------------------------------------------===// 530 // ResourceSectionReader 531 //===----------------------------------------------------------------------===// 532 533 namespace { 534 /// This class is used to read the resource section from the bytecode. 
535 class ResourceSectionReader { 536 public: 537 /// Initialize the resource section reader with the given section data. 538 LogicalResult 539 initialize(Location fileLoc, const ParserConfig &config, 540 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 541 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 542 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 543 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); 544 545 /// Parse a dialect resource handle from the resource section. 546 LogicalResult parseResourceHandle(EncodingReader &reader, 547 AsmDialectResourceHandle &result) const { 548 return parseEntry(reader, dialectResources, result, "resource handle"); 549 } 550 551 private: 552 /// The table of dialect resources within the bytecode file. 553 SmallVector<AsmDialectResourceHandle> dialectResources; 554 llvm::StringMap<std::string> dialectResourceHandleRenamingMap; 555 }; 556 557 class ParsedResourceEntry : public AsmParsedResourceEntry { 558 public: 559 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, 560 EncodingReader &reader, StringSectionReader &stringReader, 561 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 562 : key(key), kind(kind), reader(reader), stringReader(stringReader), 563 bufferOwnerRef(bufferOwnerRef) {} 564 ~ParsedResourceEntry() override = default; 565 566 StringRef getKey() const final { return key; } 567 568 InFlightDiagnostic emitError() const final { return reader.emitError(); } 569 570 AsmResourceEntryKind getKind() const final { return kind; } 571 572 FailureOr<bool> parseAsBool() const final { 573 if (kind != AsmResourceEntryKind::Bool) 574 return emitError() << "expected a bool resource entry, but found a " 575 << toString(kind) << " entry instead"; 576 577 bool value; 578 if (failed(reader.parseByte(value))) 579 return failure(); 580 return value; 581 } 582 FailureOr<std::string> parseAsString() const final { 583 if (kind != AsmResourceEntryKind::String) 
584 return emitError() << "expected a string resource entry, but found a " 585 << toString(kind) << " entry instead"; 586 587 StringRef string; 588 if (failed(stringReader.parseString(reader, string))) 589 return failure(); 590 return string.str(); 591 } 592 593 FailureOr<AsmResourceBlob> 594 parseAsBlob(BlobAllocatorFn allocator) const final { 595 if (kind != AsmResourceEntryKind::Blob) 596 return emitError() << "expected a blob resource entry, but found a " 597 << toString(kind) << " entry instead"; 598 599 ArrayRef<uint8_t> data; 600 uint64_t alignment; 601 if (failed(reader.parseBlobAndAlignment(data, alignment))) 602 return failure(); 603 604 // If we have an extendable reference to the buffer owner, we don't need to 605 // allocate a new buffer for the data, and can use the data directly. 606 if (bufferOwnerRef) { 607 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), 608 data.size()); 609 610 // Allocate an unmanager buffer which captures a reference to the owner. 611 // For now we just mark this as immutable, but in the future we should 612 // explore marking this as mutable when desired. 613 return UnmanagedAsmResourceBlob::allocateWithAlign( 614 charData, alignment, 615 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); 616 } 617 618 // Allocate memory for the blob using the provided allocator and copy the 619 // data into it. 
620 AsmResourceBlob blob = allocator(data.size(), alignment); 621 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && 622 blob.isMutable() && 623 "blob allocator did not return a properly aligned address"); 624 memcpy(blob.getMutableData().data(), data.data(), data.size()); 625 return blob; 626 } 627 628 private: 629 StringRef key; 630 AsmResourceEntryKind kind; 631 EncodingReader &reader; 632 StringSectionReader &stringReader; 633 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 634 }; 635 } // namespace 636 637 template <typename T> 638 static LogicalResult 639 parseResourceGroup(Location fileLoc, bool allowEmpty, 640 EncodingReader &offsetReader, EncodingReader &resourceReader, 641 StringSectionReader &stringReader, T *handler, 642 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, 643 function_ref<StringRef(StringRef)> remapKey = {}, 644 function_ref<LogicalResult(StringRef)> processKeyFn = {}) { 645 uint64_t numResources; 646 if (failed(offsetReader.parseVarInt(numResources))) 647 return failure(); 648 649 for (uint64_t i = 0; i < numResources; ++i) { 650 StringRef key; 651 AsmResourceEntryKind kind; 652 uint64_t resourceOffset; 653 ArrayRef<uint8_t> data; 654 if (failed(stringReader.parseString(offsetReader, key)) || 655 failed(offsetReader.parseVarInt(resourceOffset)) || 656 failed(offsetReader.parseByte(kind)) || 657 failed(resourceReader.parseBytes(resourceOffset, data))) 658 return failure(); 659 660 // Process the resource key. 661 if ((processKeyFn && failed(processKeyFn(key)))) 662 return failure(); 663 664 // If the resource data is empty and we allow it, don't error out when 665 // parsing below, just skip it. 666 if (allowEmpty && data.empty()) 667 continue; 668 669 // Ignore the entry if we don't have a valid handler. 670 if (!handler) 671 continue; 672 673 // Otherwise, parse the resource value. 
674 EncodingReader entryReader(data, fileLoc); 675 key = remapKey(key); 676 ParsedResourceEntry entry(key, kind, entryReader, stringReader, 677 bufferOwnerRef); 678 if (failed(handler->parseResource(entry))) 679 return failure(); 680 if (!entryReader.empty()) { 681 return entryReader.emitError( 682 "unexpected trailing bytes in resource entry '", key, "'"); 683 } 684 } 685 return success(); 686 } 687 688 LogicalResult ResourceSectionReader::initialize( 689 Location fileLoc, const ParserConfig &config, 690 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 691 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 692 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 693 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 694 EncodingReader resourceReader(sectionData, fileLoc); 695 EncodingReader offsetReader(offsetSectionData, fileLoc); 696 697 // Read the number of external resource providers. 698 uint64_t numExternalResourceGroups; 699 if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) 700 return failure(); 701 702 // Utility functor that dispatches to `parseResourceGroup`, but implicitly 703 // provides most of the arguments. 704 auto parseGroup = [&](auto *handler, bool allowEmpty = false, 705 function_ref<LogicalResult(StringRef)> keyFn = {}) { 706 auto resolveKey = [&](StringRef key) -> StringRef { 707 auto it = dialectResourceHandleRenamingMap.find(key); 708 if (it == dialectResourceHandleRenamingMap.end()) 709 return key; 710 return it->second; 711 }; 712 713 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, 714 stringReader, handler, bufferOwnerRef, resolveKey, 715 keyFn); 716 }; 717 718 // Read the external resources from the bytecode. 719 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { 720 StringRef key; 721 if (failed(stringReader.parseString(offsetReader, key))) 722 return failure(); 723 724 // Get the handler for these resources. 
725 // TODO: Should we require handling external resources in some scenarios? 726 AsmResourceParser *handler = config.getResourceParser(key); 727 if (!handler) { 728 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key 729 << "'"; 730 } 731 732 if (failed(parseGroup(handler))) 733 return failure(); 734 } 735 736 // Read the dialect resources from the bytecode. 737 MLIRContext *ctx = fileLoc->getContext(); 738 while (!offsetReader.empty()) { 739 std::unique_ptr<BytecodeDialect> *dialect; 740 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || 741 failed((*dialect)->load(dialectReader, ctx))) 742 return failure(); 743 Dialect *loadedDialect = (*dialect)->getLoadedDialect(); 744 if (!loadedDialect) { 745 return resourceReader.emitError() 746 << "dialect '" << (*dialect)->name << "' is unknown"; 747 } 748 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); 749 if (!handler) { 750 return resourceReader.emitError() 751 << "unexpected resources for dialect '" << (*dialect)->name << "'"; 752 } 753 754 // Ensure that each resource is declared before being processed. 755 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { 756 FailureOr<AsmDialectResourceHandle> handle = 757 handler->declareResource(key); 758 if (failed(handle)) { 759 return resourceReader.emitError() 760 << "unknown 'resource' key '" << key << "' for dialect '" 761 << (*dialect)->name << "'"; 762 } 763 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); 764 dialectResources.push_back(*handle); 765 return success(); 766 }; 767 768 // Parse the resources for this dialect. We allow empty resources because we 769 // just treat these as declarations. 
    if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
      return failure();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Attribute/Type Reader
//===----------------------------------------------------------------------===//

namespace {
/// This class provides support for reading attribute and type entries from the
/// bytecode. Attribute and Type entries are read lazily on demand, so we use
/// this reader to manage when to actually parse them from the bytecode.
///
/// `initialize` only records, for every entry, the byte range and owning
/// dialect; the actual Attribute/Type is produced the first time the entry is
/// resolved and is cached in the entry afterwards.
class AttrTypeReader {
  /// This class represents a single attribute or type entry.
  template <typename T>
  struct Entry {
    /// The entry, or null if it hasn't been resolved yet.
    T entry = {};
    /// The parent dialect of this entry.
    BytecodeDialect *dialect = nullptr;
    /// A flag indicating if the entry was encoded using a custom encoding,
    /// instead of using the textual assembly format.
    bool hasCustomEncoding = false;
    /// The raw data of this entry in the bytecode.
    ArrayRef<uint8_t> data;
  };
  using AttrEntry = Entry<Attribute>;
  using TypeEntry = Entry<Type>;

public:
  /// Construct a reader over the given section readers, dialect map, bytecode
  /// version and parser configuration. All of these except `fileLoc` are held
  /// by reference (see the member declarations below), so they must outlive
  /// this object.
  AttrTypeReader(const StringSectionReader &stringReader,
                 const ResourceSectionReader &resourceReader,
                 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
                 uint64_t &bytecodeVersion, Location fileLoc,
                 const ParserConfig &config)
      : stringReader(stringReader), resourceReader(resourceReader),
        dialectsMap(dialectsMap), fileLoc(fileLoc),
        bytecodeVersion(bytecodeVersion), parserConfig(config) {}

  /// Initialize the attribute and type information within the reader.
  /// `sectionData` holds the raw entry payloads and `offsetSectionData` the
  /// table describing how that payload is split into entries.
  LogicalResult
  initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
             ArrayRef<uint8_t> sectionData,
             ArrayRef<uint8_t> offsetSectionData);

  /// Resolve the attribute or type at the given index. Returns nullptr on
  /// failure.
  Attribute resolveAttribute(size_t index) {
    return resolveEntry(attributes, index, "Attribute");
  }
  Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }

  /// Parse a reference to an attribute or type using the given reader.
  /// The reference is a varint index into the attribute table; failure is
  /// returned if the index cannot be read or the entry cannot be resolved.
  LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
    uint64_t attrIdx;
    if (failed(reader.parseVarInt(attrIdx)))
      return failure();
    result = resolveAttribute(attrIdx);
    return success(!!result);
  }
  /// Parse an optional attribute reference, encoded as a varint index with a
  /// flag. When the flag is not set this succeeds and leaves `result`
  /// untouched.
  LogicalResult parseOptionalAttribute(EncodingReader &reader,
                                       Attribute &result) {
    uint64_t attrIdx;
    bool flag;
    if (failed(reader.parseVarIntWithFlag(attrIdx, flag)))
      return failure();
    if (!flag)
      return success();
    result = resolveAttribute(attrIdx);
    return success(!!result);
  }

  /// Parse a reference to a type: a varint index into the type table.
  LogicalResult parseType(EncodingReader &reader, Type &result) {
    uint64_t typeIdx;
    if (failed(reader.parseVarInt(typeIdx)))
      return failure();
    result = resolveType(typeIdx);
    return success(!!result);
  }

  /// Parse an attribute reference and cast it to the attribute class `T`,
  /// emitting an error through `reader` if the resolved attribute is of a
  /// different kind.
  template <typename T>
  LogicalResult parseAttribute(EncodingReader &reader, T &result) {
    Attribute baseResult;
    if (failed(parseAttribute(reader, baseResult)))
      return failure();
    if ((result = dyn_cast<T>(baseResult)))
      return success();
    return reader.emitError("expected attribute of type: ",
                            llvm::getTypeName<T>(), ", but got: ", baseResult);
  }

private:
  /// Resolve the given entry at `index`. `entryType` is only used to build
  /// diagnostics ("Attribute" or "Type").
  template <typename T>
  T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
                 StringRef entryType);

  /// Parse an entry using the given reader that was encoded using the textual
  /// assembly format.
  template <typename T>
  LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
                              StringRef entryType);

  /// Parse an entry using the given reader that was encoded using a custom
  /// bytecode format.
  template <typename T>
  LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
                                 StringRef entryType);

  /// The string section reader used to resolve string references when parsing
  /// custom encoded attribute/type entries.
  const StringSectionReader &stringReader;

  /// The resource section reader used to resolve resource references when
  /// parsing custom encoded attribute/type entries.
  const ResourceSectionReader &resourceReader;

  /// The map of the loaded dialects used to retrieve dialect information, such
  /// as the dialect version.
  const llvm::StringMap<BytecodeDialect *> &dialectsMap;

  /// The set of attribute and type entries.
  SmallVector<AttrEntry> attributes;
  SmallVector<TypeEntry> types;

  /// A location used for error emission.
  Location fileLoc;

  /// Current bytecode version being used.
  uint64_t &bytecodeVersion;

  /// Reference to the parser configuration.
  const ParserConfig &parserConfig;
};

/// Implementation of `DialectBytecodeReader` that forwards every read to an
/// `EncodingReader` and to the attribute/type, string and resource section
/// readers it was constructed with. All collaborators are held by reference.
class DialectReader : public DialectBytecodeReader {
public:
  DialectReader(AttrTypeReader &attrTypeReader,
                const StringSectionReader &stringReader,
                const ResourceSectionReader &resourceReader,
                const llvm::StringMap<BytecodeDialect *> &dialectsMap,
                EncodingReader &reader, uint64_t &bytecodeVersion)
      : attrTypeReader(attrTypeReader), stringReader(stringReader),
        resourceReader(resourceReader), dialectsMap(dialectsMap),
        reader(reader), bytecodeVersion(bytecodeVersion) {}

  /// Emit an error through the underlying encoding reader.
  InFlightDiagnostic emitError(const Twine &msg) const override {
    return reader.emitError(msg);
  }

  /// Return the version of the dialect named `dialectName`, or failure if the
  /// dialect is not in the dialect map, cannot be loaded, or has no loaded
  /// version.
  FailureOr<const DialectVersion *>
  getDialectVersion(StringRef dialectName) const override {
    // First check if the dialect is available in the map.
    auto dialectEntry = dialectsMap.find(dialectName);
    if (dialectEntry == dialectsMap.end())
      return failure();
    // If the dialect was found, try to load it. This will trigger reading the
    // bytecode version from the version buffer if it wasn't already processed.
    // Return failure if either of those two actions could not be completed.
    if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
        dialectEntry->getValue()->loadedVersion == nullptr)
      return failure();
    return dialectEntry->getValue()->loadedVersion.get();
  }

  /// Return the context of the location the underlying reader reports.
  MLIRContext *getContext() const override { return getLoc().getContext(); }

  /// Return the version of the bytecode being read.
  uint64_t getBytecodeVersion() const override { return bytecodeVersion; }

  /// Return a reader that shares every section reader with this one but pulls
  /// its bytes from `encReader` instead.
  DialectReader withEncodingReader(EncodingReader &encReader) const {
    return DialectReader(attrTypeReader, stringReader, resourceReader,
                         dialectsMap, encReader, bytecodeVersion);
  }

  /// Return the location reported by the underlying encoding reader.
  Location getLoc() const { return reader.getLoc(); }

  //===--------------------------------------------------------------------===//
  // IR
  //===--------------------------------------------------------------------===//

  LogicalResult readAttribute(Attribute &result) override {
    return attrTypeReader.parseAttribute(reader, result);
  }
  LogicalResult readOptionalAttribute(Attribute &result) override {
    return attrTypeReader.parseOptionalAttribute(reader, result);
  }
  LogicalResult readType(Type &result) override {
    return attrTypeReader.parseType(reader, result);
  }

  /// Read a resource handle through the resource section reader.
  FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
    AsmDialectResourceHandle handle;
    if (failed(resourceReader.parseResourceHandle(reader, handle)))
      return failure();
    return handle;
  }

  //===--------------------------------------------------------------------===//
  // Primitives
  //===--------------------------------------------------------------------===//

  LogicalResult readVarInt(uint64_t &result) override {
    return reader.parseVarInt(result);
  }

  /// Read a signed varint. The encoding reader hands the decoded value back in
  /// a uint64_t, which is then converted to int64_t with a static_cast.
  LogicalResult readSignedVarInt(int64_t &result) override {
    uint64_t unsignedResult;
    if (failed(reader.parseSignedVarInt(unsignedResult)))
      return failure();
    result = static_cast<int64_t>(unsignedResult);
    return success();
  }

FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { 987 // Small values are encoded using a single byte. 988 if (bitWidth <= 8) { 989 uint8_t value; 990 if (failed(reader.parseByte(value))) 991 return failure(); 992 return APInt(bitWidth, value); 993 } 994 995 // Large values up to 64 bits are encoded using a single varint. 996 if (bitWidth <= 64) { 997 uint64_t value; 998 if (failed(reader.parseSignedVarInt(value))) 999 return failure(); 1000 return APInt(bitWidth, value); 1001 } 1002 1003 // Otherwise, for really big values we encode the array of active words in 1004 // the value. 1005 uint64_t numActiveWords; 1006 if (failed(reader.parseVarInt(numActiveWords))) 1007 return failure(); 1008 SmallVector<uint64_t, 4> words(numActiveWords); 1009 for (uint64_t i = 0; i < numActiveWords; ++i) 1010 if (failed(reader.parseSignedVarInt(words[i]))) 1011 return failure(); 1012 return APInt(bitWidth, words); 1013 } 1014 1015 FailureOr<APFloat> 1016 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { 1017 FailureOr<APInt> intVal = 1018 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); 1019 if (failed(intVal)) 1020 return failure(); 1021 return APFloat(semantics, *intVal); 1022 } 1023 1024 LogicalResult readString(StringRef &result) override { 1025 return stringReader.parseString(reader, result); 1026 } 1027 1028 LogicalResult readBlob(ArrayRef<char> &result) override { 1029 uint64_t dataSize; 1030 ArrayRef<uint8_t> data; 1031 if (failed(reader.parseVarInt(dataSize)) || 1032 failed(reader.parseBytes(dataSize, data))) 1033 return failure(); 1034 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), 1035 data.size()); 1036 return success(); 1037 } 1038 1039 LogicalResult readBool(bool &result) override { 1040 return reader.parseByte(result); 1041 } 1042 1043 private: 1044 AttrTypeReader &attrTypeReader; 1045 const StringSectionReader &stringReader; 1046 const ResourceSectionReader &resourceReader; 1047 const 
llvm::StringMap<BytecodeDialect *> &dialectsMap; 1048 EncodingReader &reader; 1049 uint64_t &bytecodeVersion; 1050 }; 1051 1052 /// Wraps the properties section and handles reading properties out of it. 1053 class PropertiesSectionReader { 1054 public: 1055 /// Initialize the properties section reader with the given section data. 1056 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { 1057 if (sectionData.empty()) 1058 return success(); 1059 EncodingReader propReader(sectionData, fileLoc); 1060 uint64_t count; 1061 if (failed(propReader.parseVarInt(count))) 1062 return failure(); 1063 // Parse the raw properties buffer. 1064 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) 1065 return failure(); 1066 1067 EncodingReader offsetsReader(propertiesBuffers, fileLoc); 1068 offsetTable.reserve(count); 1069 for (auto idx : llvm::seq<int64_t>(0, count)) { 1070 (void)idx; 1071 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); 1072 ArrayRef<uint8_t> rawProperties; 1073 uint64_t dataSize; 1074 if (failed(offsetsReader.parseVarInt(dataSize)) || 1075 failed(offsetsReader.parseBytes(dataSize, rawProperties))) 1076 return failure(); 1077 } 1078 if (!offsetsReader.empty()) 1079 return offsetsReader.emitError() 1080 << "Broken properties section: didn't exhaust the offsets table"; 1081 return success(); 1082 } 1083 1084 LogicalResult read(Location fileLoc, DialectReader &dialectReader, 1085 OperationName *opName, OperationState &opState) const { 1086 uint64_t propertiesIdx; 1087 if (failed(dialectReader.readVarInt(propertiesIdx))) 1088 return failure(); 1089 if (propertiesIdx >= offsetTable.size()) 1090 return dialectReader.emitError("Properties idx out-of-bound for ") 1091 << opName->getStringRef(); 1092 size_t propertiesOffset = offsetTable[propertiesIdx]; 1093 if (propertiesIdx >= propertiesBuffers.size()) 1094 return dialectReader.emitError("Properties offset out-of-bound for ") 1095 << 
opName->getStringRef(); 1096 1097 // Acquire the sub-buffer that represent the requested properties. 1098 ArrayRef<char> rawProperties; 1099 { 1100 // "Seek" to the requested offset by getting a new reader with the right 1101 // sub-buffer. 1102 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), 1103 fileLoc); 1104 // Properties are stored as a sequence of {size + raw_data}. 1105 if (failed( 1106 dialectReader.withEncodingReader(reader).readBlob(rawProperties))) 1107 return failure(); 1108 } 1109 // Setup a new reader to read from the `rawProperties` sub-buffer. 1110 EncodingReader reader( 1111 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); 1112 DialectReader propReader = dialectReader.withEncodingReader(reader); 1113 1114 auto *iface = opName->getInterface<BytecodeOpInterface>(); 1115 if (iface) 1116 return iface->readProperties(propReader, opState); 1117 if (opName->isRegistered()) 1118 return propReader.emitError( 1119 "has properties but missing BytecodeOpInterface for ") 1120 << opName->getStringRef(); 1121 // Unregistered op are storing properties as an attribute. 1122 return propReader.readAttribute(opState.propertiesAttr); 1123 } 1124 1125 private: 1126 /// The properties buffer referenced within the bytecode file. 1127 ArrayRef<uint8_t> propertiesBuffers; 1128 1129 /// Table of offset in the buffer above. 1130 SmallVector<int64_t> offsetTable; 1131 }; 1132 } // namespace 1133 1134 LogicalResult AttrTypeReader::initialize( 1135 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 1136 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) { 1137 EncodingReader offsetReader(offsetSectionData, fileLoc); 1138 1139 // Parse the number of attribute and type entries. 
1140 uint64_t numAttributes, numTypes; 1141 if (failed(offsetReader.parseVarInt(numAttributes)) || 1142 failed(offsetReader.parseVarInt(numTypes))) 1143 return failure(); 1144 attributes.resize(numAttributes); 1145 types.resize(numTypes); 1146 1147 // A functor used to accumulate the offsets for the entries in the given 1148 // range. 1149 uint64_t currentOffset = 0; 1150 auto parseEntries = [&](auto &&range) { 1151 size_t currentIndex = 0, endIndex = range.size(); 1152 1153 // Parse an individual entry. 1154 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { 1155 auto &entry = range[currentIndex++]; 1156 1157 uint64_t entrySize; 1158 if (failed(offsetReader.parseVarIntWithFlag(entrySize, 1159 entry.hasCustomEncoding))) 1160 return failure(); 1161 1162 // Verify that the offset is actually valid. 1163 if (currentOffset + entrySize > sectionData.size()) { 1164 return offsetReader.emitError( 1165 "Attribute or Type entry offset points past the end of section"); 1166 } 1167 1168 entry.data = sectionData.slice(currentOffset, entrySize); 1169 entry.dialect = dialect; 1170 currentOffset += entrySize; 1171 return success(); 1172 }; 1173 while (currentIndex != endIndex) 1174 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) 1175 return failure(); 1176 return success(); 1177 }; 1178 1179 // Process each of the attributes, and then the types. 1180 if (failed(parseEntries(attributes)) || failed(parseEntries(types))) 1181 return failure(); 1182 1183 // Ensure that we read everything from the section. 
1184 if (!offsetReader.empty()) { 1185 return offsetReader.emitError( 1186 "unexpected trailing data in the Attribute/Type offset section"); 1187 } 1188 1189 return success(); 1190 } 1191 1192 template <typename T> 1193 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 1194 StringRef entryType) { 1195 if (index >= entries.size()) { 1196 emitError(fileLoc) << "invalid " << entryType << " index: " << index; 1197 return {}; 1198 } 1199 1200 // If the entry has already been resolved, there is nothing left to do. 1201 Entry<T> &entry = entries[index]; 1202 if (entry.entry) 1203 return entry.entry; 1204 1205 // Parse the entry. 1206 EncodingReader reader(entry.data, fileLoc); 1207 1208 // Parse based on how the entry was encoded. 1209 if (entry.hasCustomEncoding) { 1210 if (failed(parseCustomEntry(entry, reader, entryType))) 1211 return T(); 1212 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { 1213 return T(); 1214 } 1215 1216 if (!reader.empty()) { 1217 reader.emitError("unexpected trailing bytes after " + entryType + " entry"); 1218 return T(); 1219 } 1220 return entry.entry; 1221 } 1222 1223 template <typename T> 1224 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, 1225 StringRef entryType) { 1226 StringRef asmStr; 1227 if (failed(reader.parseNullTerminatedString(asmStr))) 1228 return failure(); 1229 1230 // Invoke the MLIR assembly parser to parse the entry text. 1231 size_t numRead = 0; 1232 MLIRContext *context = fileLoc->getContext(); 1233 if constexpr (std::is_same_v<T, Type>) 1234 result = 1235 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); 1236 else 1237 result = ::parseAttribute(asmStr, context, Type(), &numRead, 1238 /*isKnownNullTerminated=*/true); 1239 if (!result) 1240 return failure(); 1241 1242 // Ensure there weren't dangling characters after the entry. 
1243 if (numRead != asmStr.size()) { 1244 return reader.emitError("trailing characters found after ", entryType, 1245 " assembly format: ", asmStr.drop_front(numRead)); 1246 } 1247 return success(); 1248 } 1249 1250 template <typename T> 1251 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, 1252 EncodingReader &reader, 1253 StringRef entryType) { 1254 DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, 1255 reader, bytecodeVersion); 1256 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) 1257 return failure(); 1258 1259 if constexpr (std::is_same_v<T, Type>) { 1260 // Try parsing with callbacks first if available. 1261 for (const auto &callback : 1262 parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) { 1263 if (failed( 1264 callback->read(dialectReader, entry.dialect->name, entry.entry))) 1265 return failure(); 1266 // Early return if parsing was successful. 1267 if (!!entry.entry) 1268 return success(); 1269 1270 // Reset the reader if we failed to parse, so we can fall through the 1271 // other parsing functions. 1272 reader = EncodingReader(entry.data, reader.getLoc()); 1273 } 1274 } else { 1275 // Try parsing with callbacks first if available. 1276 for (const auto &callback : 1277 parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) { 1278 if (failed( 1279 callback->read(dialectReader, entry.dialect->name, entry.entry))) 1280 return failure(); 1281 // Early return if parsing was successful. 1282 if (!!entry.entry) 1283 return success(); 1284 1285 // Reset the reader if we failed to parse, so we can fall through the 1286 // other parsing functions. 1287 reader = EncodingReader(entry.data, reader.getLoc()); 1288 } 1289 } 1290 1291 // Ensure that the dialect implements the bytecode interface. 
1292 if (!entry.dialect->interface) { 1293 return reader.emitError("dialect '", entry.dialect->name, 1294 "' does not implement the bytecode interface"); 1295 } 1296 1297 if constexpr (std::is_same_v<T, Type>) 1298 entry.entry = entry.dialect->interface->readType(dialectReader); 1299 else 1300 entry.entry = entry.dialect->interface->readAttribute(dialectReader); 1301 1302 return success(!!entry.entry); 1303 } 1304 1305 //===----------------------------------------------------------------------===// 1306 // Bytecode Reader 1307 //===----------------------------------------------------------------------===// 1308 1309 /// This class is used to read a bytecode buffer and translate it into MLIR. 1310 class mlir::BytecodeReader::Impl { 1311 struct RegionReadState; 1312 using LazyLoadableOpsInfo = 1313 std::list<std::pair<Operation *, RegionReadState>>; 1314 using LazyLoadableOpsMap = 1315 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; 1316 1317 public: 1318 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, 1319 llvm::MemoryBufferRef buffer, 1320 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 1321 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), 1322 attrTypeReader(stringReader, resourceReader, dialectsMap, version, 1323 fileLoc, config), 1324 // Use the builtin unrealized conversion cast operation to represent 1325 // forward references to values that aren't yet defined. 1326 forwardRefOpState(UnknownLoc::get(config.getContext()), 1327 "builtin.unrealized_conversion_cast", ValueRange(), 1328 NoneType::get(config.getContext())), 1329 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} 1330 1331 /// Read the bytecode defined within `buffer` into the given block. 1332 LogicalResult read(Block *block, 1333 llvm::function_ref<bool(Operation *)> lazyOps); 1334 1335 /// Return the number of ops that haven't been materialized yet. 
1336 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } 1337 1338 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } 1339 1340 /// Materialize the provided operation, invoke the lazyOpsCallback on every 1341 /// newly found lazy operation. 1342 LogicalResult 1343 materialize(Operation *op, 1344 llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1345 this->lazyOpsCallback = lazyOpsCallback; 1346 auto resetlazyOpsCallback = 1347 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1348 auto it = lazyLoadableOpsMap.find(op); 1349 assert(it != lazyLoadableOpsMap.end() && 1350 "materialize called on non-materializable op"); 1351 return materialize(it); 1352 } 1353 1354 /// Materialize all operations. 1355 LogicalResult materializeAll() { 1356 while (!lazyLoadableOpsMap.empty()) { 1357 if (failed(materialize(lazyLoadableOpsMap.begin()))) 1358 return failure(); 1359 } 1360 return success(); 1361 } 1362 1363 /// Finalize the lazy-loading by calling back with every op that hasn't been 1364 /// materialized to let the client decide if the op should be deleted or 1365 /// materialized. The op is materialized if the callback returns true, deleted 1366 /// otherwise. 
1367 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { 1368 while (!lazyLoadableOps.empty()) { 1369 Operation *op = lazyLoadableOps.begin()->first; 1370 if (shouldMaterialize(op)) { 1371 if (failed(materialize(lazyLoadableOpsMap.find(op)))) 1372 return failure(); 1373 continue; 1374 } 1375 op->dropAllReferences(); 1376 op->erase(); 1377 lazyLoadableOps.pop_front(); 1378 lazyLoadableOpsMap.erase(op); 1379 } 1380 return success(); 1381 } 1382 1383 private: 1384 LogicalResult materialize(LazyLoadableOpsMap::iterator it) { 1385 assert(it != lazyLoadableOpsMap.end() && 1386 "materialize called on non-materializable op"); 1387 valueScopes.emplace_back(); 1388 std::vector<RegionReadState> regionStack; 1389 regionStack.push_back(std::move(it->getSecond()->second)); 1390 lazyLoadableOps.erase(it->getSecond()); 1391 lazyLoadableOpsMap.erase(it); 1392 1393 while (!regionStack.empty()) 1394 if (failed(parseRegions(regionStack, regionStack.back()))) 1395 return failure(); 1396 return success(); 1397 } 1398 1399 /// Return the context for this config. 1400 MLIRContext *getContext() const { return config.getContext(); } 1401 1402 /// Parse the bytecode version. 1403 LogicalResult parseVersion(EncodingReader &reader); 1404 1405 //===--------------------------------------------------------------------===// 1406 // Dialect Section 1407 1408 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); 1409 1410 /// Parse an operation name reference using the given reader, and set the 1411 /// `wasRegistered` flag that indicates if the bytecode was produced by a 1412 /// context where opName was registered. 1413 FailureOr<OperationName> parseOpName(EncodingReader &reader, 1414 std::optional<bool> &wasRegistered); 1415 1416 //===--------------------------------------------------------------------===// 1417 // Attribute/Type Section 1418 1419 /// Parse an attribute or type using the given reader. 
1420 template <typename T> 1421 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 1422 return attrTypeReader.parseAttribute(reader, result); 1423 } 1424 LogicalResult parseType(EncodingReader &reader, Type &result) { 1425 return attrTypeReader.parseType(reader, result); 1426 } 1427 1428 //===--------------------------------------------------------------------===// 1429 // Resource Section 1430 1431 LogicalResult 1432 parseResourceSection(EncodingReader &reader, 1433 std::optional<ArrayRef<uint8_t>> resourceData, 1434 std::optional<ArrayRef<uint8_t>> resourceOffsetData); 1435 1436 //===--------------------------------------------------------------------===// 1437 // IR Section 1438 1439 /// This struct represents the current read state of a range of regions. This 1440 /// struct is used to enable iterative parsing of regions. 1441 struct RegionReadState { 1442 RegionReadState(Operation *op, EncodingReader *reader, 1443 bool isIsolatedFromAbove) 1444 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} 1445 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, 1446 bool isIsolatedFromAbove) 1447 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), 1448 isIsolatedFromAbove(isIsolatedFromAbove) {} 1449 1450 /// The current regions being read. 1451 MutableArrayRef<Region>::iterator curRegion, endRegion; 1452 /// This is the reader to use for this region, this pointer is pointing to 1453 /// the parent region reader unless the current region is IsolatedFromAbove, 1454 /// in which case the pointer is pointing to the `owningReader` which is a 1455 /// section dedicated to the current region. 1456 EncodingReader *reader; 1457 std::unique_ptr<EncodingReader> owningReader; 1458 1459 /// The number of values defined immediately within this region. 1460 unsigned numValues = 0; 1461 1462 /// The current blocks of the region being read. 
1463 SmallVector<Block *> curBlocks; 1464 Region::iterator curBlock = {}; 1465 1466 /// The number of operations remaining to be read from the current block 1467 /// being read. 1468 uint64_t numOpsRemaining = 0; 1469 1470 /// A flag indicating if the regions being read are isolated from above. 1471 bool isIsolatedFromAbove = false; 1472 }; 1473 1474 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); 1475 LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, 1476 RegionReadState &readState); 1477 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, 1478 RegionReadState &readState, 1479 bool &isIsolatedFromAbove); 1480 1481 LogicalResult parseRegion(RegionReadState &readState); 1482 LogicalResult parseBlockHeader(EncodingReader &reader, 1483 RegionReadState &readState); 1484 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); 1485 1486 //===--------------------------------------------------------------------===// 1487 // Value Processing 1488 1489 /// Parse an operand reference using the given reader. Returns nullptr in the 1490 /// case of failure. 1491 Value parseOperand(EncodingReader &reader); 1492 1493 /// Sequentially define the given value range. 1494 LogicalResult defineValues(EncodingReader &reader, ValueRange values); 1495 1496 /// Create a value to use for a forward reference. 1497 Value createForwardRef(); 1498 1499 //===--------------------------------------------------------------------===// 1500 // Use-list order helpers 1501 1502 /// This struct is a simple storage that contains information required to 1503 /// reorder the use-list of a value with respect to the pre-order traversal 1504 /// ordering. 
1505 struct UseListOrderStorage { 1506 UseListOrderStorage(bool isIndexPairEncoding, 1507 SmallVector<unsigned, 4> &&indices) 1508 : indices(std::move(indices)), 1509 isIndexPairEncoding(isIndexPairEncoding){}; 1510 /// The vector containing the information required to reorder the 1511 /// use-list of a value. 1512 SmallVector<unsigned, 4> indices; 1513 1514 /// Whether indices represent a pair of type `(src, dst)` or it is a direct 1515 /// indexing, such as `dst = order[src]`. 1516 bool isIndexPairEncoding; 1517 }; 1518 1519 /// Parse use-list order from bytecode for a range of values if available. The 1520 /// range is expected to be either a block argument or an op result range. On 1521 /// success, return a map of the position in the range and the use-list order 1522 /// encoding. The function assumes to know the size of the range it is 1523 /// processing. 1524 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; 1525 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, 1526 uint64_t rangeSize); 1527 1528 /// Shuffle the use-chain according to the order parsed. 1529 LogicalResult sortUseListOrder(Value value); 1530 1531 /// Recursively visit all the values defined within topLevelOp and sort the 1532 /// use-list orders according to the indices parsed. 1533 LogicalResult processUseLists(Operation *topLevelOp); 1534 1535 //===--------------------------------------------------------------------===// 1536 // Fields 1537 1538 /// This class represents a single value scope, in which a value scope is 1539 /// delimited by isolated from above regions. 1540 struct ValueScope { 1541 /// Push a new region state onto this scope, reserving enough values for 1542 /// those defined within the current region of the provided state. 
1543 void push(RegionReadState &readState) { 1544 nextValueIDs.push_back(values.size()); 1545 values.resize(values.size() + readState.numValues); 1546 } 1547 1548 /// Pop the values defined for the current region within the provided region 1549 /// state. 1550 void pop(RegionReadState &readState) { 1551 values.resize(values.size() - readState.numValues); 1552 nextValueIDs.pop_back(); 1553 } 1554 1555 /// The set of values defined in this scope. 1556 std::vector<Value> values; 1557 1558 /// The ID for the next defined value for each region current being 1559 /// processed in this scope. 1560 SmallVector<unsigned, 4> nextValueIDs; 1561 }; 1562 1563 /// The configuration of the parser. 1564 const ParserConfig &config; 1565 1566 /// A location to use when emitting errors. 1567 Location fileLoc; 1568 1569 /// Flag that indicates if lazyloading is enabled. 1570 bool lazyLoading; 1571 1572 /// Keep track of operations that have been lazy loaded (their regions haven't 1573 /// been materialized), along with the `RegionReadState` that allows to 1574 /// lazy-load the regions nested under the operation. 1575 LazyLoadableOpsInfo lazyLoadableOps; 1576 LazyLoadableOpsMap lazyLoadableOpsMap; 1577 llvm::function_ref<bool(Operation *)> lazyOpsCallback; 1578 1579 /// The reader used to process attribute and types within the bytecode. 1580 AttrTypeReader attrTypeReader; 1581 1582 /// The version of the bytecode being read. 1583 uint64_t version = 0; 1584 1585 /// The producer of the bytecode being read. 1586 StringRef producer; 1587 1588 /// The table of IR units referenced within the bytecode file. 1589 SmallVector<std::unique_ptr<BytecodeDialect>> dialects; 1590 llvm::StringMap<BytecodeDialect *> dialectsMap; 1591 SmallVector<BytecodeOperationName> opNames; 1592 1593 /// The reader used to process resources within the bytecode. 1594 ResourceSectionReader resourceReader; 1595 1596 /// Worklist of values with custom use-list orders to process before the end 1597 /// of the parsing. 
  DenseMap<void *, UseListOrderStorage> valueToUseListMap;

  /// The table of strings referenced within the bytecode file.
  StringSectionReader stringReader;

  /// The table of properties referenced by the operation in the bytecode file.
  PropertiesSectionReader propertiesReader;

  /// The current set of available IR value scopes.
  std::vector<ValueScope> valueScopes;

  /// The global pre-order operation ordering.
  DenseMap<Operation *, unsigned> operationIDs;

  /// A block containing the set of operations defined to create forward
  /// references.
  Block forwardRefOps;

  /// A block containing previously created, and no longer used, forward
  /// reference operations.
  Block openForwardRefOps;

  /// An operation state used when instantiating forward references.
  OperationState forwardRefOpState;

  /// Reference to the input buffer.
  llvm::MemoryBufferRef buffer;

  /// The optional owning source manager, which when present may be used to
  /// extend the lifetime of the input buffer.
  /// NOTE(review): this is held as a reference to the caller's shared_ptr, not
  /// as a copy, so the referenced shared_ptr must itself outlive this object.
  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};

/// Read the whole bytecode buffer into `block`: skip the magic number, parse
/// the version and producer, split the buffer into its top-level sections,
/// check that every required section is present exactly once, and then process
/// the sections in dependency order, finishing with the IR section.
/// `lazyOpsCallback` is installed for the duration of the call only.
LogicalResult BytecodeReader::Impl::read(
    Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
  EncodingReader reader(buffer.getBuffer(), fileLoc);
  this->lazyOpsCallback = lazyOpsCallback;
  // Clear the callback again on every exit path.
  auto resetlazyOpsCallback =
      llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });

  // Skip over the bytecode header, this should have already been checked.
  if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
    return failure();
  // Parse the bytecode version and producer.
  if (failed(parseVersion(reader)) ||
      failed(reader.parseNullTerminatedString(producer)))
    return failure();

  // Add a diagnostic handler that attaches a note that includes the original
  // producer of the bytecode.
  ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
    diag.attachNote() << "in bytecode version " << version
                      << " produced by: " << producer;
    return failure();
  });

  // Parse the raw data for each of the top-level sections of the bytecode.
  std::optional<ArrayRef<uint8_t>>
      sectionDatas[bytecode::Section::kNumSections];
  while (!reader.empty()) {
    // Read the next section from the bytecode.
    bytecode::Section::ID sectionID;
    ArrayRef<uint8_t> sectionData;
    if (failed(reader.parseSection(sectionID, sectionData)))
      return failure();

    // Check for duplicate sections, we only expect one instance of each.
    // NOTE(review): `sectionID` indexes `sectionDatas` directly; this relies
    // on parseSection (not visible here) rejecting IDs >= kNumSections.
    if (sectionDatas[sectionID]) {
      return reader.emitError("duplicate top-level section: ",
                              ::toString(sectionID));
    }
    sectionDatas[sectionID] = sectionData;
  }
  // Check that all of the required sections were found.
  for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
    bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
    if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
      return reader.emitError("missing data for top-level section: ",
                              ::toString(sectionID));
    }
  }

  // Process the string section first.
  if (failed(stringReader.initialize(
          fileLoc, *sectionDatas[bytecode::Section::kString])))
    return failure();

  // Process the properties section.
  if (sectionDatas[bytecode::Section::kProperties] &&
      failed(propertiesReader.initialize(
          fileLoc, *sectionDatas[bytecode::Section::kProperties])))
    return failure();

  // Process the dialect section.
  if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
    return failure();

  // Process the resource section if present.
  if (failed(parseResourceSection(
          reader, sectionDatas[bytecode::Section::kResource],
          sectionDatas[bytecode::Section::kResourceOffset])))
    return failure();

  // Process the attribute and type section.
  if (failed(attrTypeReader.initialize(
          dialects, *sectionDatas[bytecode::Section::kAttrType],
          *sectionDatas[bytecode::Section::kAttrTypeOffset])))
    return failure();

  // Finally, process the IR section.
  return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
}

/// Parse the bytecode version into `version` and reject versions outside the
/// range [kMinSupportedVersion, kVersion]. Lazy loading is switched off for
/// versions older than kLazyLoading.
LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
  if (failed(reader.parseVarInt(version)))
    return failure();

  // Validate the bytecode version.
  uint64_t currentVersion = bytecode::kVersion;
  uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
  if (version < minSupportedVersion) {
    return reader.emitError("bytecode version ", version,
                            " is older than the current version of ",
                            currentVersion, ", and upgrade is not supported");
  }
  if (version > currentVersion) {
    return reader.emitError("bytecode version ", version,
                            " is newer than the current version ",
                            currentVersion);
  }
  // Override any request to lazy-load if the bytecode version is too old.
  if (version < bytecode::kLazyLoading)
    lazyLoading = false;
  return success();
}

//===----------------------------------------------------------------------===//
// Dialect Section

/// Load this dialect into `ctx` if that has not been done yet, look up its
/// bytecode interface, and, when a version buffer is present, read the dialect
/// version through that interface. An unknown dialect is an error unless the
/// context allows unregistered dialects.
LogicalResult BytecodeDialect::load(const DialectReader &reader,
                                    MLIRContext *ctx) {
  // Nothing to do if a dialect was already loaded by an earlier call.
  if (dialect)
    return success();
  Dialect *loadedDialect = ctx->getOrLoadDialect(name);
  if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
    return reader.emitError("dialect '")
           << name
           << "' is unknown. If this is intended, please call "
              "allowUnregisteredDialects() on the MLIRContext, or use "
              "-allow-unregistered-dialect with the MLIR tool used.";
  }
  dialect = loadedDialect;

  // If the dialect was actually loaded, check to see if it has a bytecode
  // interface.
  if (loadedDialect)
    interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
  if (!versionBuffer.empty()) {
    // A version entry can only be decoded by the dialect's own interface.
    if (!interface)
      return reader.emitError("dialect '")
             << name
             << "' does not implement the bytecode interface, "
                "but found a version entry";
    EncodingReader encReader(versionBuffer, reader.getLoc());
    DialectReader versionReader = reader.withEncodingReader(encReader);
    loadedVersion = interface->readVersion(versionReader);
    if (!loadedVersion)
      return failure();
  }
  return success();
}

LogicalResult
BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
  EncodingReader sectionReader(sectionData, fileLoc);

  // Parse the number of dialects in the section.
  uint64_t numDialects;
  if (failed(sectionReader.parseVarInt(numDialects)))
    return failure();
  dialects.resize(numDialects);

  // Parse each of the dialects.
  for (uint64_t i = 0; i < numDialects; ++i) {
    dialects[i] = std::make_unique<BytecodeDialect>();
    /// Before version kDialectVersioning, there wasn't any versioning available
    /// for dialects, and the entryIdx represents the string itself.
    // NOTE(review): this path `continue`s before the `dialectsMap` insertion
    // at the bottom of the loop, so dialects from such old files are never
    // added to the map. Confirm that this is intended.
    if (version < bytecode::kDialectVersioning) {
      if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
        return failure();
      continue;
    }

    // Parse ID representing dialect and version.
    uint64_t dialectNameIdx;
    bool versionAvailable;
    if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
                                                 versionAvailable)))
      return failure();
    if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
                                               dialects[i]->name)))
      return failure();
    if (versionAvailable) {
      bytecode::Section::ID sectionID;
      if (failed(sectionReader.parseSection(sectionID,
                                            dialects[i]->versionBuffer)))
        return failure();
      if (sectionID != bytecode::Section::kDialectVersions) {
        emitError(fileLoc, "expected dialect version section");
        return failure();
      }
    }
    dialectsMap[dialects[i]->name] = dialects[i].get();
  }

  // Parse the operation names, which are grouped by dialect.
  auto parseOpName = [&](BytecodeDialect *dialect) {
    StringRef opName;
    std::optional<bool> wasRegistered;
    // Prior to version kNativePropertiesEncoding, the information about whether
    // an op was registered or not wasn't encoded.
    if (version < bytecode::kNativePropertiesEncoding) {
      if (failed(stringReader.parseString(sectionReader, opName)))
        return failure();
    } else {
      bool wasRegisteredFlag;
      if (failed(stringReader.parseStringWithFlag(sectionReader, opName,
                                                  wasRegisteredFlag)))
        return failure();
      wasRegistered = wasRegisteredFlag;
    }
    opNames.emplace_back(dialect, opName, wasRegistered);
    return success();
  };
  // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation
  // where the number of ops are known.
1834 if (version >= bytecode::kElideUnknownBlockArgLocation) { 1835 uint64_t numOps; 1836 if (failed(sectionReader.parseVarInt(numOps))) 1837 return failure(); 1838 opNames.reserve(numOps); 1839 } 1840 while (!sectionReader.empty()) 1841 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) 1842 return failure(); 1843 return success(); 1844 } 1845 1846 FailureOr<OperationName> 1847 BytecodeReader::Impl::parseOpName(EncodingReader &reader, 1848 std::optional<bool> &wasRegistered) { 1849 BytecodeOperationName *opName = nullptr; 1850 if (failed(parseEntry(reader, opNames, opName, "operation name"))) 1851 return failure(); 1852 wasRegistered = opName->wasRegistered; 1853 // Check to see if this operation name has already been resolved. If we 1854 // haven't, load the dialect and build the operation name. 1855 if (!opName->opName) { 1856 // If the opName is empty, this is because we use to accept names such as 1857 // `foo` without any `.` separator. We shouldn't tolerate this in textual 1858 // format anymore but for now we'll be backward compatible. This can only 1859 // happen with unregistered dialects. 1860 if (opName->name.empty()) { 1861 opName->opName.emplace(opName->dialect->name, getContext()); 1862 } else { 1863 // Load the dialect and its version. 1864 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1865 dialectsMap, reader, version); 1866 if (failed(opName->dialect->load(dialectReader, getContext()))) 1867 return failure(); 1868 opName->opName.emplace((opName->dialect->name + "." 
+ opName->name).str(), 1869 getContext()); 1870 } 1871 } 1872 return *opName->opName; 1873 } 1874 1875 //===----------------------------------------------------------------------===// 1876 // Resource Section 1877 1878 LogicalResult BytecodeReader::Impl::parseResourceSection( 1879 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, 1880 std::optional<ArrayRef<uint8_t>> resourceOffsetData) { 1881 // Ensure both sections are either present or not. 1882 if (resourceData.has_value() != resourceOffsetData.has_value()) { 1883 if (resourceOffsetData) 1884 return emitError(fileLoc, "unexpected resource offset section when " 1885 "resource section is not present"); 1886 return emitError( 1887 fileLoc, 1888 "expected resource offset section when resource section is present"); 1889 } 1890 1891 // If the resource sections are absent, there is nothing to do. 1892 if (!resourceData) 1893 return success(); 1894 1895 // Initialize the resource reader with the resource sections. 1896 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1897 dialectsMap, reader, version); 1898 return resourceReader.initialize(fileLoc, config, dialects, stringReader, 1899 *resourceData, *resourceOffsetData, 1900 dialectReader, bufferOwnerRef); 1901 } 1902 1903 //===----------------------------------------------------------------------===// 1904 // UseListOrder Helpers 1905 1906 FailureOr<BytecodeReader::Impl::UseListMapT> 1907 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, 1908 uint64_t numResults) { 1909 BytecodeReader::Impl::UseListMapT map; 1910 uint64_t numValuesToRead = 1; 1911 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) 1912 return failure(); 1913 1914 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { 1915 uint64_t resultIdx = 0; 1916 if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) 1917 return failure(); 1918 1919 uint64_t numValues; 1920 bool indexPairEncoding; 1921 if 
(failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) 1922 return failure(); 1923 1924 SmallVector<unsigned, 4> useListOrders; 1925 for (size_t idx = 0; idx < numValues; idx++) { 1926 uint64_t index; 1927 if (failed(reader.parseVarInt(index))) 1928 return failure(); 1929 useListOrders.push_back(index); 1930 } 1931 1932 // Store in a map the result index 1933 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, 1934 std::move(useListOrders))); 1935 } 1936 1937 return map; 1938 } 1939 1940 /// Sorts each use according to the order specified in the use-list parsed. If 1941 /// the custom use-list is not found, this means that the order needs to be 1942 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie 1943 /// on the same operation, the order will follow the reverse operand number 1944 /// ordering. 1945 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { 1946 // Early return for trivial use-lists. 1947 if (value.use_empty() || value.hasOneUse()) 1948 return success(); 1949 1950 bool hasIncomingOrder = 1951 valueToUseListMap.contains(value.getAsOpaquePointer()); 1952 1953 // Compute the current order of the use-list with respect to the global 1954 // ordering. Detect if the order is already sorted while doing so. 1955 bool alreadySorted = true; 1956 auto &firstUse = *value.use_begin(); 1957 uint64_t prevID = 1958 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); 1959 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; 1960 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1961 uint64_t currentID = bytecode::getUseID( 1962 item.value(), operationIDs.at(item.value().getOwner())); 1963 alreadySorted &= prevID > currentID; 1964 currentOrder.push_back({item.index(), currentID}); 1965 prevID = currentID; 1966 } 1967 1968 // If the order is already sorted, and there wasn't a custom order to apply 1969 // from the bytecode file, we are done. 
1970 if (alreadySorted && !hasIncomingOrder) 1971 return success(); 1972 1973 // If not already sorted, sort the indices of the current order by descending 1974 // useIDs. 1975 if (!alreadySorted) 1976 std::sort( 1977 currentOrder.begin(), currentOrder.end(), 1978 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1979 1980 if (!hasIncomingOrder) { 1981 // If the bytecode file did not contain any custom use-list order, it means 1982 // that the order was descending useID. Hence, shuffle by the first index 1983 // of the `currentOrder` pair. 1984 SmallVector<unsigned> shuffle = SmallVector<unsigned>( 1985 llvm::map_range(currentOrder, [&](auto item) { return item.first; })); 1986 value.shuffleUseList(shuffle); 1987 return success(); 1988 } 1989 1990 // Pull the custom order info from the map. 1991 UseListOrderStorage customOrder = 1992 valueToUseListMap.at(value.getAsOpaquePointer()); 1993 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); 1994 uint64_t numUses = 1995 std::distance(value.getUses().begin(), value.getUses().end()); 1996 1997 // If the encoding was a pair of indices `(src, dst)` for every permutation, 1998 // reconstruct the shuffle vector for every use. Initialize the shuffle vector 1999 // as identity, and then apply the mapping encoded in the indices. 2000 if (customOrder.isIndexPairEncoding) { 2001 // Return failure if the number of indices was not representing pairs. 2002 if (shuffle.size() & 1) 2003 return failure(); 2004 2005 SmallVector<unsigned, 4> newShuffle(numUses); 2006 size_t idx = 0; 2007 std::iota(newShuffle.begin(), newShuffle.end(), idx); 2008 for (idx = 0; idx < shuffle.size(); idx += 2) 2009 newShuffle[shuffle[idx]] = shuffle[idx + 1]; 2010 2011 shuffle = std::move(newShuffle); 2012 } 2013 2014 // Make sure that the indices represent a valid mapping. That is, the sum of 2015 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no 2016 // duplicates are allowed in the list. 
2017 DenseSet<unsigned> set; 2018 uint64_t accumulator = 0; 2019 for (const auto &elem : shuffle) { 2020 if (!set.insert(elem).second) 2021 return failure(); 2022 accumulator += elem; 2023 } 2024 if (numUses != shuffle.size() || 2025 accumulator != (((numUses - 1) * numUses) >> 1)) 2026 return failure(); 2027 2028 // Apply the current ordering map onto the shuffle vector to get the final 2029 // use-list sorting indices before shuffling. 2030 shuffle = SmallVector<unsigned, 4>(llvm::map_range( 2031 currentOrder, [&](auto item) { return shuffle[item.first]; })); 2032 value.shuffleUseList(shuffle); 2033 return success(); 2034 } 2035 2036 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { 2037 // Precompute operation IDs according to the pre-order walk of the IR. We 2038 // can't do this while parsing since parseRegions ordering is not strictly 2039 // equal to the pre-order walk. 2040 unsigned operationID = 0; 2041 topLevelOp->walk<mlir::WalkOrder::PreOrder>( 2042 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); 2043 2044 auto blockWalk = topLevelOp->walk([this](Block *block) { 2045 for (auto arg : block->getArguments()) 2046 if (failed(sortUseListOrder(arg))) 2047 return WalkResult::interrupt(); 2048 return WalkResult::advance(); 2049 }); 2050 2051 auto resultWalk = topLevelOp->walk([this](Operation *op) { 2052 for (auto result : op->getResults()) 2053 if (failed(sortUseListOrder(result))) 2054 return WalkResult::interrupt(); 2055 return WalkResult::advance(); 2056 }); 2057 2058 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); 2059 } 2060 2061 //===----------------------------------------------------------------------===// 2062 // IR Section 2063 2064 LogicalResult 2065 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, 2066 Block *block) { 2067 EncodingReader reader(sectionData, fileLoc); 2068 2069 // A stack of operation regions currently being read from the bytecode. 
2070 std::vector<RegionReadState> regionStack; 2071 2072 // Parse the top-level block using a temporary module operation. 2073 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); 2074 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); 2075 regionStack.back().curBlocks.push_back(moduleOp->getBody()); 2076 regionStack.back().curBlock = regionStack.back().curRegion->begin(); 2077 if (failed(parseBlockHeader(reader, regionStack.back()))) 2078 return failure(); 2079 valueScopes.emplace_back(); 2080 valueScopes.back().push(regionStack.back()); 2081 2082 // Iteratively parse regions until everything has been resolved. 2083 while (!regionStack.empty()) 2084 if (failed(parseRegions(regionStack, regionStack.back()))) 2085 return failure(); 2086 if (!forwardRefOps.empty()) { 2087 return reader.emitError( 2088 "not all forward unresolved forward operand references"); 2089 } 2090 2091 // Sort use-lists according to what specified in bytecode. 2092 if (failed(processUseLists(*moduleOp))) 2093 return reader.emitError( 2094 "parsed use-list orders were invalid and could not be applied"); 2095 2096 // Resolve dialect version. 2097 for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) { 2098 // Parsing is complete, give an opportunity to each dialect to visit the 2099 // IR and perform upgrades. 2100 if (!byteCodeDialect->loadedVersion) 2101 continue; 2102 if (byteCodeDialect->interface && 2103 failed(byteCodeDialect->interface->upgradeFromVersion( 2104 *moduleOp, *byteCodeDialect->loadedVersion))) 2105 return failure(); 2106 } 2107 2108 // Verify that the parsed operations are valid. 2109 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) 2110 return failure(); 2111 2112 // Splice the parsed operations over to the provided top-level block. 
2113 auto &parsedOps = moduleOp->getBody()->getOperations(); 2114 auto &destOps = block->getOperations(); 2115 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); 2116 return success(); 2117 } 2118 2119 LogicalResult 2120 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, 2121 RegionReadState &readState) { 2122 // Process regions, blocks, and operations until the end or if a nested 2123 // region is encountered. In this case we push a new state in regionStack and 2124 // return, the processing of the current region will resume afterward. 2125 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { 2126 // If the current block hasn't been setup yet, parse the header for this 2127 // region. The current block is already setup when this function was 2128 // interrupted to recurse down in a nested region and we resume the current 2129 // block after processing the nested region. 2130 if (readState.curBlock == Region::iterator()) { 2131 if (failed(parseRegion(readState))) 2132 return failure(); 2133 2134 // If the region is empty, there is nothing to more to do. 2135 if (readState.curRegion->empty()) 2136 continue; 2137 } 2138 2139 // Parse the blocks within the region. 2140 EncodingReader &reader = *readState.reader; 2141 do { 2142 while (readState.numOpsRemaining--) { 2143 // Read in the next operation. We don't read its regions directly, we 2144 // handle those afterwards as necessary. 2145 bool isIsolatedFromAbove = false; 2146 FailureOr<Operation *> op = 2147 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); 2148 if (failed(op)) 2149 return failure(); 2150 2151 // If the op has regions, add it to the stack for processing and return: 2152 // we stop the processing of the current region and resume it after the 2153 // inner one is completed. Unless LazyLoading is activated in which case 2154 // nested region parsing is delayed. 
2155 if ((*op)->getNumRegions()) { 2156 RegionReadState childState(*op, &reader, isIsolatedFromAbove); 2157 2158 // Isolated regions are encoded as a section in version 2 and above. 2159 if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { 2160 bytecode::Section::ID sectionID; 2161 ArrayRef<uint8_t> sectionData; 2162 if (failed(reader.parseSection(sectionID, sectionData))) 2163 return failure(); 2164 if (sectionID != bytecode::Section::kIR) 2165 return emitError(fileLoc, "expected IR section for region"); 2166 childState.owningReader = 2167 std::make_unique<EncodingReader>(sectionData, fileLoc); 2168 childState.reader = childState.owningReader.get(); 2169 2170 // If the user has a callback set, they have the opportunity to 2171 // control lazyloading as we go. 2172 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) { 2173 lazyLoadableOps.emplace_back(*op, std::move(childState)); 2174 lazyLoadableOpsMap.try_emplace(*op, 2175 std::prev(lazyLoadableOps.end())); 2176 continue; 2177 } 2178 } 2179 regionStack.push_back(std::move(childState)); 2180 2181 // If the op is isolated from above, push a new value scope. 2182 if (isIsolatedFromAbove) 2183 valueScopes.emplace_back(); 2184 return success(); 2185 } 2186 } 2187 2188 // Move to the next block of the region. 2189 if (++readState.curBlock == readState.curRegion->end()) 2190 break; 2191 if (failed(parseBlockHeader(reader, readState))) 2192 return failure(); 2193 } while (true); 2194 2195 // Reset the current block and any values reserved for this region. 2196 readState.curBlock = {}; 2197 valueScopes.back().pop(readState); 2198 } 2199 2200 // When the regions have been fully parsed, pop them off of the read stack. If 2201 // the regions were isolated from above, we also pop the last value scope. 
2202 if (readState.isIsolatedFromAbove) { 2203 assert(!valueScopes.empty() && "Expect a valueScope after reading region"); 2204 valueScopes.pop_back(); 2205 } 2206 assert(!regionStack.empty() && "Expect a regionStack after reading region"); 2207 regionStack.pop_back(); 2208 return success(); 2209 } 2210 2211 FailureOr<Operation *> 2212 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, 2213 RegionReadState &readState, 2214 bool &isIsolatedFromAbove) { 2215 // Parse the name of the operation. 2216 std::optional<bool> wasRegistered; 2217 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); 2218 if (failed(opName)) 2219 return failure(); 2220 2221 // Parse the operation mask, which indicates which components of the operation 2222 // are present. 2223 uint8_t opMask; 2224 if (failed(reader.parseByte(opMask))) 2225 return failure(); 2226 2227 /// Parse the location. 2228 LocationAttr opLoc; 2229 if (failed(parseAttribute(reader, opLoc))) 2230 return failure(); 2231 2232 // With the location and name resolved, we can start building the operation 2233 // state. 2234 OperationState opState(opLoc, *opName); 2235 2236 // Parse the attributes of the operation. 2237 if (opMask & bytecode::OpEncodingMask::kHasAttrs) { 2238 DictionaryAttr dictAttr; 2239 if (failed(parseAttribute(reader, dictAttr))) 2240 return failure(); 2241 opState.attributes = dictAttr; 2242 } 2243 2244 if (opMask & bytecode::OpEncodingMask::kHasProperties) { 2245 // kHasProperties wasn't emitted in older bytecode, we should never get 2246 // there without also having the `wasRegistered` flag available. 2247 if (!wasRegistered) 2248 return emitError(fileLoc, 2249 "Unexpected missing `wasRegistered` opname flag at " 2250 "bytecode version ") 2251 << version << " with properties."; 2252 // When an operation is emitted without being registered, the properties are 2253 // stored as an attribute. 
Otherwise the op must implement the bytecode 2254 // interface and control the serialization. 2255 if (wasRegistered) { 2256 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 2257 dialectsMap, reader, version); 2258 if (failed( 2259 propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) 2260 return failure(); 2261 } else { 2262 // If the operation wasn't registered when it was emitted, the properties 2263 // was serialized as an attribute. 2264 if (failed(parseAttribute(reader, opState.propertiesAttr))) 2265 return failure(); 2266 } 2267 } 2268 2269 /// Parse the results of the operation. 2270 if (opMask & bytecode::OpEncodingMask::kHasResults) { 2271 uint64_t numResults; 2272 if (failed(reader.parseVarInt(numResults))) 2273 return failure(); 2274 opState.types.resize(numResults); 2275 for (int i = 0, e = numResults; i < e; ++i) 2276 if (failed(parseType(reader, opState.types[i]))) 2277 return failure(); 2278 } 2279 2280 /// Parse the operands of the operation. 2281 if (opMask & bytecode::OpEncodingMask::kHasOperands) { 2282 uint64_t numOperands; 2283 if (failed(reader.parseVarInt(numOperands))) 2284 return failure(); 2285 opState.operands.resize(numOperands); 2286 for (int i = 0, e = numOperands; i < e; ++i) 2287 if (!(opState.operands[i] = parseOperand(reader))) 2288 return failure(); 2289 } 2290 2291 /// Parse the successors of the operation. 2292 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { 2293 uint64_t numSuccs; 2294 if (failed(reader.parseVarInt(numSuccs))) 2295 return failure(); 2296 opState.successors.resize(numSuccs); 2297 for (int i = 0, e = numSuccs; i < e; ++i) { 2298 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], 2299 "successor"))) 2300 return failure(); 2301 } 2302 } 2303 2304 /// Parse the use-list orders for the results of the operation. Use-list 2305 /// orders are available since version 3 of the bytecode. 
2306 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; 2307 if (version >= bytecode::kUseListOrdering && 2308 (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { 2309 size_t numResults = opState.types.size(); 2310 auto parseResult = parseUseListOrderForRange(reader, numResults); 2311 if (failed(parseResult)) 2312 return failure(); 2313 resultIdxToUseListMap = std::move(*parseResult); 2314 } 2315 2316 /// Parse the regions of the operation. 2317 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { 2318 uint64_t numRegions; 2319 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) 2320 return failure(); 2321 2322 opState.regions.reserve(numRegions); 2323 for (int i = 0, e = numRegions; i < e; ++i) 2324 opState.regions.push_back(std::make_unique<Region>()); 2325 } 2326 2327 // Create the operation at the back of the current block. 2328 Operation *op = Operation::create(opState); 2329 readState.curBlock->push_back(op); 2330 2331 // If the operation had results, update the value references. We don't need to 2332 // do this if the current value scope is empty. That is, the op was not 2333 // encoded within a parent region. 2334 if (readState.numValues && op->getNumResults() && 2335 failed(defineValues(reader, op->getResults()))) 2336 return failure(); 2337 2338 /// Store a map for every value that received a custom use-list order from the 2339 /// bytecode file. 2340 if (resultIdxToUseListMap.has_value()) { 2341 for (size_t idx = 0; idx < op->getNumResults(); idx++) { 2342 if (resultIdxToUseListMap->contains(idx)) { 2343 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), 2344 resultIdxToUseListMap->at(idx)); 2345 } 2346 } 2347 } 2348 return op; 2349 } 2350 2351 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { 2352 EncodingReader &reader = *readState.reader; 2353 2354 // Parse the number of blocks in the region. 
2355 uint64_t numBlocks; 2356 if (failed(reader.parseVarInt(numBlocks))) 2357 return failure(); 2358 2359 // If the region is empty, there is nothing else to do. 2360 if (numBlocks == 0) 2361 return success(); 2362 2363 // Parse the number of values defined in this region. 2364 uint64_t numValues; 2365 if (failed(reader.parseVarInt(numValues))) 2366 return failure(); 2367 readState.numValues = numValues; 2368 2369 // Create the blocks within this region. We do this before processing so that 2370 // we can rely on the blocks existing when creating operations. 2371 readState.curBlocks.clear(); 2372 readState.curBlocks.reserve(numBlocks); 2373 for (uint64_t i = 0; i < numBlocks; ++i) { 2374 readState.curBlocks.push_back(new Block()); 2375 readState.curRegion->push_back(readState.curBlocks.back()); 2376 } 2377 2378 // Prepare the current value scope for this region. 2379 valueScopes.back().push(readState); 2380 2381 // Parse the entry block of the region. 2382 readState.curBlock = readState.curRegion->begin(); 2383 return parseBlockHeader(reader, readState); 2384 } 2385 2386 LogicalResult 2387 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, 2388 RegionReadState &readState) { 2389 bool hasArgs; 2390 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) 2391 return failure(); 2392 2393 // Parse the arguments of the block. 2394 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) 2395 return failure(); 2396 2397 // Uselist orders are available since version 3 of the bytecode. 
2398 if (version < bytecode::kUseListOrdering) 2399 return success(); 2400 2401 uint8_t hasUseListOrders = 0; 2402 if (hasArgs && failed(reader.parseByte(hasUseListOrders))) 2403 return failure(); 2404 2405 if (!hasUseListOrders) 2406 return success(); 2407 2408 Block &blk = *readState.curBlock; 2409 auto argIdxToUseListMap = 2410 parseUseListOrderForRange(reader, blk.getNumArguments()); 2411 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) 2412 return failure(); 2413 2414 for (size_t idx = 0; idx < blk.getNumArguments(); idx++) 2415 if (argIdxToUseListMap->contains(idx)) 2416 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), 2417 argIdxToUseListMap->at(idx)); 2418 2419 // We don't parse the operations of the block here, that's done elsewhere. 2420 return success(); 2421 } 2422 2423 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, 2424 Block *block) { 2425 // Parse the value ID for the first argument, and the number of arguments. 2426 uint64_t numArgs; 2427 if (failed(reader.parseVarInt(numArgs))) 2428 return failure(); 2429 2430 SmallVector<Type> argTypes; 2431 SmallVector<Location> argLocs; 2432 argTypes.reserve(numArgs); 2433 argLocs.reserve(numArgs); 2434 2435 Location unknownLoc = UnknownLoc::get(config.getContext()); 2436 while (numArgs--) { 2437 Type argType; 2438 LocationAttr argLoc = unknownLoc; 2439 if (version >= bytecode::kElideUnknownBlockArgLocation) { 2440 // Parse the type with hasLoc flag to determine if it has type. 2441 uint64_t typeIdx; 2442 bool hasLoc; 2443 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || 2444 !(argType = attrTypeReader.resolveType(typeIdx))) 2445 return failure(); 2446 if (hasLoc && failed(parseAttribute(reader, argLoc))) 2447 return failure(); 2448 } else { 2449 // All args has type and location. 
2450 if (failed(parseType(reader, argType)) || 2451 failed(parseAttribute(reader, argLoc))) 2452 return failure(); 2453 } 2454 argTypes.push_back(argType); 2455 argLocs.push_back(argLoc); 2456 } 2457 block->addArguments(argTypes, argLocs); 2458 return defineValues(reader, block->getArguments()); 2459 } 2460 2461 //===----------------------------------------------------------------------===// 2462 // Value Processing 2463 2464 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { 2465 std::vector<Value> &values = valueScopes.back().values; 2466 Value *value = nullptr; 2467 if (failed(parseEntry(reader, values, value, "value"))) 2468 return Value(); 2469 2470 // Create a new forward reference if necessary. 2471 if (!*value) 2472 *value = createForwardRef(); 2473 return *value; 2474 } 2475 2476 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, 2477 ValueRange newValues) { 2478 ValueScope &valueScope = valueScopes.back(); 2479 std::vector<Value> &values = valueScope.values; 2480 2481 unsigned &valueID = valueScope.nextValueIDs.back(); 2482 unsigned valueIDEnd = valueID + newValues.size(); 2483 if (valueIDEnd > values.size()) { 2484 return reader.emitError( 2485 "value index range was outside of the expected range for " 2486 "the parent region, got [", 2487 valueID, ", ", valueIDEnd, "), but the maximum index was ", 2488 values.size() - 1); 2489 } 2490 2491 // Assign the values and update any forward references. 2492 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { 2493 Value newValue = newValues[i]; 2494 2495 // Check to see if a definition for this value already exists. 2496 if (Value oldValue = std::exchange(values[valueID], newValue)) { 2497 Operation *forwardRefOp = oldValue.getDefiningOp(); 2498 2499 // Assert that this is a forward reference operation. Given how we compute 2500 // definition ids (incrementally as we parse), it shouldn't be possible 2501 // for the value to be defined any other way. 
2502 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && 2503 "value index was already defined?"); 2504 2505 oldValue.replaceAllUsesWith(newValue); 2506 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); 2507 } 2508 } 2509 return success(); 2510 } 2511 2512 Value BytecodeReader::Impl::createForwardRef() { 2513 // Check for an available existing operation to use. Otherwise, create a new 2514 // fake operation to use for the reference. 2515 if (!openForwardRefOps.empty()) { 2516 Operation *op = &openForwardRefOps.back(); 2517 op->moveBefore(&forwardRefOps, forwardRefOps.end()); 2518 } else { 2519 forwardRefOps.push_back(Operation::create(forwardRefOpState)); 2520 } 2521 return forwardRefOps.back().getResult(0); 2522 } 2523 2524 //===----------------------------------------------------------------------===// 2525 // Entry Points 2526 //===----------------------------------------------------------------------===// 2527 2528 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } 2529 2530 BytecodeReader::BytecodeReader( 2531 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, 2532 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2533 Location sourceFileLoc = 2534 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2535 /*line=*/0, /*column=*/0); 2536 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, 2537 bufferOwnerRef); 2538 } 2539 2540 LogicalResult BytecodeReader::readTopLevel( 2541 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2542 return impl->read(block, lazyOpsCallback); 2543 } 2544 2545 int64_t BytecodeReader::getNumOpsToMaterialize() const { 2546 return impl->getNumOpsToMaterialize(); 2547 } 2548 2549 bool BytecodeReader::isMaterializable(Operation *op) { 2550 return impl->isMaterializable(op); 2551 } 2552 2553 LogicalResult BytecodeReader::materialize( 2554 Operation *op, 
llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2555 return impl->materialize(op, lazyOpsCallback); 2556 } 2557 2558 LogicalResult 2559 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { 2560 return impl->finalize(shouldMaterialize); 2561 } 2562 2563 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { 2564 return buffer.getBuffer().starts_with("ML\xefR"); 2565 } 2566 2567 /// Read the bytecode from the provided memory buffer reference. 2568 /// `bufferOwnerRef` if provided is the owning source manager for the buffer, 2569 /// and may be used to extend the lifetime of the buffer. 2570 static LogicalResult 2571 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, 2572 const ParserConfig &config, 2573 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2574 Location sourceFileLoc = 2575 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2576 /*line=*/0, /*column=*/0); 2577 if (!isBytecode(buffer)) { 2578 return emitError(sourceFileLoc, 2579 "input buffer is not an MLIR bytecode file"); 2580 } 2581 2582 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, 2583 buffer, bufferOwnerRef); 2584 return reader.read(block, /*lazyOpsCallback=*/nullptr); 2585 } 2586 2587 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, 2588 const ParserConfig &config) { 2589 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); 2590 } 2591 LogicalResult 2592 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 2593 Block *block, const ParserConfig &config) { 2594 return readBytecodeFileImpl( 2595 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, 2596 sourceMgr); 2597 } 2598