1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Bytecode/BytecodeReader.h" 10 #include "mlir/AsmParser/AsmParser.h" 11 #include "mlir/Bytecode/BytecodeImplementation.h" 12 #include "mlir/Bytecode/BytecodeOpInterface.h" 13 #include "mlir/Bytecode/Encoding.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/Verifier.h" 18 #include "mlir/IR/Visitors.h" 19 #include "mlir/Support/LLVM.h" 20 #include "mlir/Support/LogicalResult.h" 21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/ScopeExit.h" 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/Support/Endian.h" 26 #include "llvm/Support/MemoryBufferRef.h" 27 #include "llvm/Support/SourceMgr.h" 28 29 #include <cstddef> 30 #include <list> 31 #include <memory> 32 #include <numeric> 33 #include <optional> 34 35 #define DEBUG_TYPE "mlir-bytecode-reader" 36 37 using namespace mlir; 38 39 /// Stringify the given section ID. 
40 static std::string toString(bytecode::Section::ID sectionID) { 41 switch (sectionID) { 42 case bytecode::Section::kString: 43 return "String (0)"; 44 case bytecode::Section::kDialect: 45 return "Dialect (1)"; 46 case bytecode::Section::kAttrType: 47 return "AttrType (2)"; 48 case bytecode::Section::kAttrTypeOffset: 49 return "AttrTypeOffset (3)"; 50 case bytecode::Section::kIR: 51 return "IR (4)"; 52 case bytecode::Section::kResource: 53 return "Resource (5)"; 54 case bytecode::Section::kResourceOffset: 55 return "ResourceOffset (6)"; 56 case bytecode::Section::kDialectVersions: 57 return "DialectVersions (7)"; 58 case bytecode::Section::kProperties: 59 return "Properties (8)"; 60 default: 61 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); 62 } 63 } 64 65 /// Returns true if the given top-level section ID is optional. 66 static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { 67 switch (sectionID) { 68 case bytecode::Section::kString: 69 case bytecode::Section::kDialect: 70 case bytecode::Section::kAttrType: 71 case bytecode::Section::kAttrTypeOffset: 72 case bytecode::Section::kIR: 73 return false; 74 case bytecode::Section::kResource: 75 case bytecode::Section::kResourceOffset: 76 case bytecode::Section::kDialectVersions: 77 return true; 78 case bytecode::Section::kProperties: 79 return version < bytecode::kNativePropertiesEncoding; 80 default: 81 llvm_unreachable("unknown section ID"); 82 } 83 } 84 85 //===----------------------------------------------------------------------===// 86 // EncodingReader 87 //===----------------------------------------------------------------------===// 88 89 namespace { 90 class EncodingReader { 91 public: 92 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) 93 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {} 94 explicit EncodingReader(StringRef contents, Location fileLoc) 95 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), 
96 contents.size()}, 97 fileLoc) {} 98 99 /// Returns true if the entire section has been read. 100 bool empty() const { return dataIt == buffer.end(); } 101 102 /// Returns the remaining size of the bytecode. 103 size_t size() const { return buffer.end() - dataIt; } 104 105 /// Align the current reader position to the specified alignment. 106 LogicalResult alignTo(unsigned alignment) { 107 if (!llvm::isPowerOf2_32(alignment)) 108 return emitError("expected alignment to be a power-of-two"); 109 110 auto isUnaligned = [&](const uint8_t *ptr) { 111 return ((uintptr_t)ptr & (alignment - 1)) != 0; 112 }; 113 114 // Shift the reader position to the next alignment boundary. 115 while (isUnaligned(dataIt)) { 116 uint8_t padding; 117 if (failed(parseByte(padding))) 118 return failure(); 119 if (padding != bytecode::kAlignmentByte) { 120 return emitError("expected alignment byte (0xCB), but got: '0x" + 121 llvm::utohexstr(padding) + "'"); 122 } 123 } 124 125 // Ensure the data iterator is now aligned. This case is unlikely because we 126 // *just* went through the effort to align the data iterator. 127 if (LLVM_UNLIKELY(isUnaligned(dataIt))) { 128 return emitError("expected data iterator aligned to ", alignment, 129 ", but got pointer: '0x" + 130 llvm::utohexstr((uintptr_t)dataIt) + "'"); 131 } 132 133 return success(); 134 } 135 136 /// Emit an error using the given arguments. 137 template <typename... Args> 138 InFlightDiagnostic emitError(Args &&...args) const { 139 return ::emitError(fileLoc).append(std::forward<Args>(args)...); 140 } 141 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } 142 143 /// Parse a single byte from the stream. 144 template <typename T> 145 LogicalResult parseByte(T &value) { 146 if (empty()) 147 return emitError("attempting to parse a byte at the end of the bytecode"); 148 value = static_cast<T>(*dataIt++); 149 return success(); 150 } 151 /// Parse a range of bytes of 'length' into the given result. 
152 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { 153 if (length > size()) { 154 return emitError("attempting to parse ", length, " bytes when only ", 155 size(), " remain"); 156 } 157 result = {dataIt, length}; 158 dataIt += length; 159 return success(); 160 } 161 /// Parse a range of bytes of 'length' into the given result, which can be 162 /// assumed to be large enough to hold `length`. 163 LogicalResult parseBytes(size_t length, uint8_t *result) { 164 if (length > size()) { 165 return emitError("attempting to parse ", length, " bytes when only ", 166 size(), " remain"); 167 } 168 memcpy(result, dataIt, length); 169 dataIt += length; 170 return success(); 171 } 172 173 /// Parse an aligned blob of data, where the alignment was encoded alongside 174 /// the data. 175 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, 176 uint64_t &alignment) { 177 uint64_t dataSize; 178 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || 179 failed(alignTo(alignment))) 180 return failure(); 181 return parseBytes(dataSize, data); 182 } 183 184 /// Parse a variable length encoded integer from the byte stream. The first 185 /// encoded byte contains a prefix in the low bits indicating the encoded 186 /// length of the value. This length prefix is a bit sequence of '0's followed 187 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes 188 /// (not including the prefix byte). All remaining bits in the first byte, 189 /// along with all of the bits in additional bytes, provide the value of the 190 /// integer encoded in little-endian order. 191 LogicalResult parseVarInt(uint64_t &result) { 192 // Parse the first byte of the encoding, which contains the length prefix. 193 if (failed(parseByte(result))) 194 return failure(); 195 196 // Handle the overwhelmingly common case where the value is stored in a 197 // single byte. In this case, the first bit is the `1` marker bit. 
198 if (LLVM_LIKELY(result & 1)) { 199 result >>= 1; 200 return success(); 201 } 202 203 // Handle the overwhelming uncommon case where the value required all 8 204 // bytes (i.e. a really really big number). In this case, the marker byte is 205 // all zeros: `00000000`. 206 if (LLVM_UNLIKELY(result == 0)) { 207 llvm::support::ulittle64_t resultLE; 208 if (failed(parseBytes(sizeof(resultLE), 209 reinterpret_cast<uint8_t *>(&resultLE)))) 210 return failure(); 211 result = resultLE; 212 return success(); 213 } 214 return parseMultiByteVarInt(result); 215 } 216 217 /// Parse a signed variable length encoded integer from the byte stream. A 218 /// signed varint is encoded as a normal varint with zigzag encoding applied, 219 /// i.e. the low bit of the value is used to indicate the sign. 220 LogicalResult parseSignedVarInt(uint64_t &result) { 221 if (failed(parseVarInt(result))) 222 return failure(); 223 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) 224 result = (result >> 1) ^ (~(result & 1) + 1); 225 return success(); 226 } 227 228 /// Parse a variable length encoded integer whose low bit is used to encode an 229 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. 230 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { 231 if (failed(parseVarInt(result))) 232 return failure(); 233 flag = result & 1; 234 result >>= 1; 235 return success(); 236 } 237 238 /// Skip the first `length` bytes within the reader. 239 LogicalResult skipBytes(size_t length) { 240 if (length > size()) { 241 return emitError("attempting to skip ", length, " bytes when only ", 242 size(), " remain"); 243 } 244 dataIt += length; 245 return success(); 246 } 247 248 /// Parse a null-terminated string into `result` (without including the NUL 249 /// terminator). 
250 LogicalResult parseNullTerminatedString(StringRef &result) { 251 const char *startIt = (const char *)dataIt; 252 const char *nulIt = (const char *)memchr(startIt, 0, size()); 253 if (!nulIt) 254 return emitError( 255 "malformed null-terminated string, no null character found"); 256 257 result = StringRef(startIt, nulIt - startIt); 258 dataIt = (const uint8_t *)nulIt + 1; 259 return success(); 260 } 261 262 /// Parse a section header, placing the kind of section in `sectionID` and the 263 /// contents of the section in `sectionData`. 264 LogicalResult parseSection(bytecode::Section::ID §ionID, 265 ArrayRef<uint8_t> §ionData) { 266 uint8_t sectionIDAndHasAlignment; 267 uint64_t length; 268 if (failed(parseByte(sectionIDAndHasAlignment)) || 269 failed(parseVarInt(length))) 270 return failure(); 271 272 // Extract the section ID and whether the section is aligned. The high bit 273 // of the ID is the alignment flag. 274 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & 275 0b01111111); 276 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; 277 278 // Check that the section is actually valid before trying to process its 279 // data. 280 if (sectionID >= bytecode::Section::kNumSections) 281 return emitError("invalid section ID: ", unsigned(sectionID)); 282 283 // Process the section alignment if present. 284 if (hasAlignment) { 285 uint64_t alignment; 286 if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) 287 return failure(); 288 } 289 290 // Parse the actual section data. 291 return parseBytes(static_cast<size_t>(length), sectionData); 292 } 293 294 Location getLoc() const { return fileLoc; } 295 296 private: 297 /// Parse a variable length encoded integer from the byte stream. This method 298 /// is a fallback when the number of bytes used to encode the value is greater 299 /// than 1, but less than the max (9). The provided `result` value can be 300 /// assumed to already contain the first byte of the value. 
/// NOTE: This method is marked noinline to avoid pessimizing the common case
/// of single byte encoding.
LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
  // Count the number of trailing zeros in the marker byte, this indicates the
  // number of trailing bytes that are part of the value. We use `uint32_t`
  // here because we only care about the first byte, and so that we actually
  // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
  // implementation).
  uint32_t numBytes = llvm::countr_zero<uint32_t>(result);
  assert(numBytes > 0 && numBytes <= 7 &&
         "unexpected number of trailing zeros in varint encoding");

  // Parse in the remaining bytes of the value. `resultLE` is stored in
  // little-endian order and is seeded with the already-parsed first byte, so
  // the additional bytes are read directly into the storage that follows it.
  llvm::support::ulittle64_t resultLE(result);
  if (failed(
          parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1)))
    return failure();

  // Shift out the low-order bits that were used to mark how the value was
  // encoded: `numBytes` zero bits plus the terminating one bit.
  result = resultLE >> (numBytes + 1);
  return success();
}

/// The bytecode buffer.
ArrayRef<uint8_t> buffer;

/// The current iterator within the 'buffer'.
const uint8_t *dataIt;

/// A location for the bytecode used to report errors.
Location fileLoc;
};
} // namespace

/// Resolve an index into the given entry list. `entry` may either be a
/// reference, in which case it is assigned to the corresponding value in
/// `entries`, or a pointer, in which case it is assigned to the address of the
/// element in `entries`. Emits an "invalid <entryStr> index" error through
/// `reader` and fails if `index` is out of range.
template <typename RangeT, typename T>
static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
                                  uint64_t index, T &entry,
                                  StringRef entryStr) {
  if (index >= entries.size())
    return reader.emitError("invalid ", entryStr, " index: ", index);

  // If an element of `entries` converts to `T`, assign the element itself.
  // Otherwise the provided entry is a pointer, so resolve to the address of
  // the element.
  if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
    entry = entries[index];
  else
    entry = &entries[index];
  return success();
}

/// Parse and resolve an index into the given entry list. The index is read
/// from `reader` as a varint and then resolved with `resolveEntry`.
template <typename RangeT, typename T>
static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
                                T &entry, StringRef entryStr) {
  uint64_t entryIdx;
  if (failed(reader.parseVarInt(entryIdx)))
    return failure();
  return resolveEntry(reader, entries, entryIdx, entry, entryStr);
}

//===----------------------------------------------------------------------===//
// StringSectionReader
//===----------------------------------------------------------------------===//

namespace {
/// This class is used to read references to the string section from the
/// bytecode.
class StringSectionReader {
public:
  /// Initialize the string section reader with the given section data.
  LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);

  /// Parse a shared string from the string section. The shared string is
  /// encoded using an index to a corresponding string in the string section.
  LogicalResult parseString(EncodingReader &reader, StringRef &result) {
    return parseEntry(reader, strings, result, "string");
  }

  /// Parse a shared string from the string section. The shared string is
  /// encoded using an index to a corresponding string in the string section.
  /// This variant parses a flag compressed with the index.
  LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
                                    bool &flag) {
    uint64_t entryIdx;
    if (failed(reader.parseVarIntWithFlag(entryIdx, flag)))
      return failure();
    return parseStringAtIndex(reader, entryIdx, result);
  }

  /// Parse a shared string from the string section.
The shared string is 395 /// encoded using an index to a corresponding string in the string section. 396 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, 397 StringRef &result) { 398 return resolveEntry(reader, strings, index, result, "string"); 399 } 400 401 private: 402 /// The table of strings referenced within the bytecode file. 403 SmallVector<StringRef> strings; 404 }; 405 } // namespace 406 407 LogicalResult StringSectionReader::initialize(Location fileLoc, 408 ArrayRef<uint8_t> sectionData) { 409 EncodingReader stringReader(sectionData, fileLoc); 410 411 // Parse the number of strings in the section. 412 uint64_t numStrings; 413 if (failed(stringReader.parseVarInt(numStrings))) 414 return failure(); 415 strings.resize(numStrings); 416 417 // Parse each of the strings. The sizes of the strings are encoded in reverse 418 // order, so that's the order we populate the table. 419 size_t stringDataEndOffset = sectionData.size(); 420 for (StringRef &string : llvm::reverse(strings)) { 421 uint64_t stringSize; 422 if (failed(stringReader.parseVarInt(stringSize))) 423 return failure(); 424 if (stringDataEndOffset < stringSize) { 425 return stringReader.emitError( 426 "string size exceeds the available data size"); 427 } 428 429 // Extract the string from the data, dropping the null character. 430 size_t stringOffset = stringDataEndOffset - stringSize; 431 string = StringRef( 432 reinterpret_cast<const char *>(sectionData.data() + stringOffset), 433 stringSize - 1); 434 stringDataEndOffset = stringOffset; 435 } 436 437 // Check that the only remaining data was for the strings, i.e. the reader 438 // should be at the same offset as the first string. 
439 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { 440 return stringReader.emitError("unexpected trailing data between the " 441 "offsets for strings and their data"); 442 } 443 return success(); 444 } 445 446 //===----------------------------------------------------------------------===// 447 // BytecodeDialect 448 //===----------------------------------------------------------------------===// 449 450 namespace { 451 class DialectReader; 452 453 /// This struct represents a dialect entry within the bytecode. 454 struct BytecodeDialect { 455 /// Load the dialect into the provided context if it hasn't been loaded yet. 456 /// Returns failure if the dialect couldn't be loaded *and* the provided 457 /// context does not allow unregistered dialects. The provided reader is used 458 /// for error emission if necessary. 459 LogicalResult load(const DialectReader &reader, MLIRContext *ctx); 460 461 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can 462 /// only be called after `load`. 463 Dialect *getLoadedDialect() const { 464 assert(dialect && 465 "expected `load` to be invoked before `getLoadedDialect`"); 466 return *dialect; 467 } 468 469 /// The loaded dialect entry. This field is std::nullopt if we haven't 470 /// attempted to load, nullptr if we failed to load, otherwise the loaded 471 /// dialect. 472 std::optional<Dialect *> dialect; 473 474 /// The bytecode interface of the dialect, or nullptr if the dialect does not 475 /// implement the bytecode interface. This field should only be checked if the 476 /// `dialect` field is not std::nullopt. 477 const BytecodeDialectInterface *interface = nullptr; 478 479 /// The name of the dialect. 480 StringRef name; 481 482 /// A buffer containing the encoding of the dialect version parsed. 483 ArrayRef<uint8_t> versionBuffer; 484 485 /// Lazy loaded dialect version from the handle above. 
486 std::unique_ptr<DialectVersion> loadedVersion; 487 }; 488 489 /// This struct represents an operation name entry within the bytecode. 490 struct BytecodeOperationName { 491 BytecodeOperationName(BytecodeDialect *dialect, StringRef name, 492 std::optional<bool> wasRegistered) 493 : dialect(dialect), name(name), wasRegistered(wasRegistered) {} 494 495 /// The loaded operation name, or std::nullopt if it hasn't been processed 496 /// yet. 497 std::optional<OperationName> opName; 498 499 /// The dialect that owns this operation name. 500 BytecodeDialect *dialect; 501 502 /// The name of the operation, without the dialect prefix. 503 StringRef name; 504 505 /// Whether this operation was registered when the bytecode was produced. 506 /// This flag is populated when bytecode version >=kNativePropertiesEncoding. 507 std::optional<bool> wasRegistered; 508 }; 509 } // namespace 510 511 /// Parse a single dialect group encoded in the byte stream. 512 static LogicalResult parseDialectGrouping( 513 EncodingReader &reader, 514 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 515 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { 516 // Parse the dialect and the number of entries in the group. 517 std::unique_ptr<BytecodeDialect> *dialect; 518 if (failed(parseEntry(reader, dialects, dialect, "dialect"))) 519 return failure(); 520 uint64_t numEntries; 521 if (failed(reader.parseVarInt(numEntries))) 522 return failure(); 523 524 for (uint64_t i = 0; i < numEntries; ++i) 525 if (failed(entryCallback(dialect->get()))) 526 return failure(); 527 return success(); 528 } 529 530 //===----------------------------------------------------------------------===// 531 // ResourceSectionReader 532 //===----------------------------------------------------------------------===// 533 534 namespace { 535 /// This class is used to read the resource section from the bytecode. 
536 class ResourceSectionReader { 537 public: 538 /// Initialize the resource section reader with the given section data. 539 LogicalResult 540 initialize(Location fileLoc, const ParserConfig &config, 541 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 542 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 543 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 544 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); 545 546 /// Parse a dialect resource handle from the resource section. 547 LogicalResult parseResourceHandle(EncodingReader &reader, 548 AsmDialectResourceHandle &result) { 549 return parseEntry(reader, dialectResources, result, "resource handle"); 550 } 551 552 private: 553 /// The table of dialect resources within the bytecode file. 554 SmallVector<AsmDialectResourceHandle> dialectResources; 555 llvm::StringMap<std::string> dialectResourceHandleRenamingMap; 556 }; 557 558 class ParsedResourceEntry : public AsmParsedResourceEntry { 559 public: 560 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, 561 EncodingReader &reader, StringSectionReader &stringReader, 562 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 563 : key(key), kind(kind), reader(reader), stringReader(stringReader), 564 bufferOwnerRef(bufferOwnerRef) {} 565 ~ParsedResourceEntry() override = default; 566 567 StringRef getKey() const final { return key; } 568 569 InFlightDiagnostic emitError() const final { return reader.emitError(); } 570 571 AsmResourceEntryKind getKind() const final { return kind; } 572 573 FailureOr<bool> parseAsBool() const final { 574 if (kind != AsmResourceEntryKind::Bool) 575 return emitError() << "expected a bool resource entry, but found a " 576 << toString(kind) << " entry instead"; 577 578 bool value; 579 if (failed(reader.parseByte(value))) 580 return failure(); 581 return value; 582 } 583 FailureOr<std::string> parseAsString() const final { 584 if (kind != AsmResourceEntryKind::String) 585 
return emitError() << "expected a string resource entry, but found a " 586 << toString(kind) << " entry instead"; 587 588 StringRef string; 589 if (failed(stringReader.parseString(reader, string))) 590 return failure(); 591 return string.str(); 592 } 593 594 FailureOr<AsmResourceBlob> 595 parseAsBlob(BlobAllocatorFn allocator) const final { 596 if (kind != AsmResourceEntryKind::Blob) 597 return emitError() << "expected a blob resource entry, but found a " 598 << toString(kind) << " entry instead"; 599 600 ArrayRef<uint8_t> data; 601 uint64_t alignment; 602 if (failed(reader.parseBlobAndAlignment(data, alignment))) 603 return failure(); 604 605 // If we have an extendable reference to the buffer owner, we don't need to 606 // allocate a new buffer for the data, and can use the data directly. 607 if (bufferOwnerRef) { 608 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), 609 data.size()); 610 611 // Allocate an unmanager buffer which captures a reference to the owner. 612 // For now we just mark this as immutable, but in the future we should 613 // explore marking this as mutable when desired. 614 return UnmanagedAsmResourceBlob::allocateWithAlign( 615 charData, alignment, 616 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); 617 } 618 619 // Allocate memory for the blob using the provided allocator and copy the 620 // data into it. 
621 AsmResourceBlob blob = allocator(data.size(), alignment); 622 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && 623 blob.isMutable() && 624 "blob allocator did not return a properly aligned address"); 625 memcpy(blob.getMutableData().data(), data.data(), data.size()); 626 return blob; 627 } 628 629 private: 630 StringRef key; 631 AsmResourceEntryKind kind; 632 EncodingReader &reader; 633 StringSectionReader &stringReader; 634 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 635 }; 636 } // namespace 637 638 template <typename T> 639 static LogicalResult 640 parseResourceGroup(Location fileLoc, bool allowEmpty, 641 EncodingReader &offsetReader, EncodingReader &resourceReader, 642 StringSectionReader &stringReader, T *handler, 643 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, 644 function_ref<StringRef(StringRef)> remapKey = {}, 645 function_ref<LogicalResult(StringRef)> processKeyFn = {}) { 646 uint64_t numResources; 647 if (failed(offsetReader.parseVarInt(numResources))) 648 return failure(); 649 650 for (uint64_t i = 0; i < numResources; ++i) { 651 StringRef key; 652 AsmResourceEntryKind kind; 653 uint64_t resourceOffset; 654 ArrayRef<uint8_t> data; 655 if (failed(stringReader.parseString(offsetReader, key)) || 656 failed(offsetReader.parseVarInt(resourceOffset)) || 657 failed(offsetReader.parseByte(kind)) || 658 failed(resourceReader.parseBytes(resourceOffset, data))) 659 return failure(); 660 661 // Process the resource key. 662 if ((processKeyFn && failed(processKeyFn(key)))) 663 return failure(); 664 665 // If the resource data is empty and we allow it, don't error out when 666 // parsing below, just skip it. 667 if (allowEmpty && data.empty()) 668 continue; 669 670 // Ignore the entry if we don't have a valid handler. 671 if (!handler) 672 continue; 673 674 // Otherwise, parse the resource value. 
675 EncodingReader entryReader(data, fileLoc); 676 key = remapKey(key); 677 ParsedResourceEntry entry(key, kind, entryReader, stringReader, 678 bufferOwnerRef); 679 if (failed(handler->parseResource(entry))) 680 return failure(); 681 if (!entryReader.empty()) { 682 return entryReader.emitError( 683 "unexpected trailing bytes in resource entry '", key, "'"); 684 } 685 } 686 return success(); 687 } 688 689 LogicalResult ResourceSectionReader::initialize( 690 Location fileLoc, const ParserConfig &config, 691 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 692 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 693 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 694 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 695 EncodingReader resourceReader(sectionData, fileLoc); 696 EncodingReader offsetReader(offsetSectionData, fileLoc); 697 698 // Read the number of external resource providers. 699 uint64_t numExternalResourceGroups; 700 if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) 701 return failure(); 702 703 // Utility functor that dispatches to `parseResourceGroup`, but implicitly 704 // provides most of the arguments. 705 auto parseGroup = [&](auto *handler, bool allowEmpty = false, 706 function_ref<LogicalResult(StringRef)> keyFn = {}) { 707 auto resolveKey = [&](StringRef key) -> StringRef { 708 auto it = dialectResourceHandleRenamingMap.find(key); 709 if (it == dialectResourceHandleRenamingMap.end()) 710 return ""; 711 return it->second; 712 }; 713 714 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, 715 stringReader, handler, bufferOwnerRef, resolveKey, 716 keyFn); 717 }; 718 719 // Read the external resources from the bytecode. 720 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { 721 StringRef key; 722 if (failed(stringReader.parseString(offsetReader, key))) 723 return failure(); 724 725 // Get the handler for these resources. 
726 // TODO: Should we require handling external resources in some scenarios? 727 AsmResourceParser *handler = config.getResourceParser(key); 728 if (!handler) { 729 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key 730 << "'"; 731 } 732 733 if (failed(parseGroup(handler))) 734 return failure(); 735 } 736 737 // Read the dialect resources from the bytecode. 738 MLIRContext *ctx = fileLoc->getContext(); 739 while (!offsetReader.empty()) { 740 std::unique_ptr<BytecodeDialect> *dialect; 741 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || 742 failed((*dialect)->load(dialectReader, ctx))) 743 return failure(); 744 Dialect *loadedDialect = (*dialect)->getLoadedDialect(); 745 if (!loadedDialect) { 746 return resourceReader.emitError() 747 << "dialect '" << (*dialect)->name << "' is unknown"; 748 } 749 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); 750 if (!handler) { 751 return resourceReader.emitError() 752 << "unexpected resources for dialect '" << (*dialect)->name << "'"; 753 } 754 755 // Ensure that each resource is declared before being processed. 756 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { 757 FailureOr<AsmDialectResourceHandle> handle = 758 handler->declareResource(key); 759 if (failed(handle)) { 760 return resourceReader.emitError() 761 << "unknown 'resource' key '" << key << "' for dialect '" 762 << (*dialect)->name << "'"; 763 } 764 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); 765 dialectResources.push_back(*handle); 766 return success(); 767 }; 768 769 // Parse the resources for this dialect. We allow empty resources because we 770 // just treat these as declarations. 
771 if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) 772 return failure(); 773 } 774 775 return success(); 776 } 777 778 //===----------------------------------------------------------------------===// 779 // Attribute/Type Reader 780 //===----------------------------------------------------------------------===// 781 782 namespace { 783 /// This class provides support for reading attribute and type entries from the 784 /// bytecode. Attribute and Type entries are read lazily on demand, so we use 785 /// this reader to manage when to actually parse them from the bytecode. 786 class AttrTypeReader { 787 /// This class represents a single attribute or type entry. 788 template <typename T> 789 struct Entry { 790 /// The entry, or null if it hasn't been resolved yet. 791 T entry = {}; 792 /// The parent dialect of this entry. 793 BytecodeDialect *dialect = nullptr; 794 /// A flag indicating if the entry was encoded using a custom encoding, 795 /// instead of using the textual assembly format. 796 bool hasCustomEncoding = false; 797 /// The raw data of this entry in the bytecode. 798 ArrayRef<uint8_t> data; 799 }; 800 using AttrEntry = Entry<Attribute>; 801 using TypeEntry = Entry<Type>; 802 803 public: 804 AttrTypeReader(StringSectionReader &stringReader, 805 ResourceSectionReader &resourceReader, 806 const llvm::StringMap<BytecodeDialect *> &dialectsMap, 807 uint64_t &bytecodeVersion, Location fileLoc, 808 const ParserConfig &config) 809 : stringReader(stringReader), resourceReader(resourceReader), 810 dialectsMap(dialectsMap), fileLoc(fileLoc), 811 bytecodeVersion(bytecodeVersion), parserConfig(config) {} 812 813 /// Initialize the attribute and type information within the reader. 814 LogicalResult 815 initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 816 ArrayRef<uint8_t> sectionData, 817 ArrayRef<uint8_t> offsetSectionData); 818 819 /// Resolve the attribute or type at the given index. 
Returns nullptr on 820 /// failure. 821 Attribute resolveAttribute(size_t index) { 822 return resolveEntry(attributes, index, "Attribute"); 823 } 824 Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } 825 826 /// Parse a reference to an attribute or type using the given reader. 827 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { 828 uint64_t attrIdx; 829 if (failed(reader.parseVarInt(attrIdx))) 830 return failure(); 831 result = resolveAttribute(attrIdx); 832 return success(!!result); 833 } 834 LogicalResult parseOptionalAttribute(EncodingReader &reader, 835 Attribute &result) { 836 uint64_t attrIdx; 837 bool flag; 838 if (failed(reader.parseVarIntWithFlag(attrIdx, flag))) 839 return failure(); 840 if (!flag) 841 return success(); 842 result = resolveAttribute(attrIdx); 843 return success(!!result); 844 } 845 846 LogicalResult parseType(EncodingReader &reader, Type &result) { 847 uint64_t typeIdx; 848 if (failed(reader.parseVarInt(typeIdx))) 849 return failure(); 850 result = resolveType(typeIdx); 851 return success(!!result); 852 } 853 854 template <typename T> 855 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 856 Attribute baseResult; 857 if (failed(parseAttribute(reader, baseResult))) 858 return failure(); 859 if ((result = dyn_cast<T>(baseResult))) 860 return success(); 861 return reader.emitError("expected attribute of type: ", 862 llvm::getTypeName<T>(), ", but got: ", baseResult); 863 } 864 865 private: 866 /// Resolve the given entry at `index`. 867 template <typename T> 868 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 869 StringRef entryType); 870 871 /// Parse an entry using the given reader that was encoded using the textual 872 /// assembly format. 
873 template <typename T> 874 LogicalResult parseAsmEntry(T &result, EncodingReader &reader, 875 StringRef entryType); 876 877 /// Parse an entry using the given reader that was encoded using a custom 878 /// bytecode format. 879 template <typename T> 880 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, 881 StringRef entryType); 882 883 /// The string section reader used to resolve string references when parsing 884 /// custom encoded attribute/type entries. 885 StringSectionReader &stringReader; 886 887 /// The resource section reader used to resolve resource references when 888 /// parsing custom encoded attribute/type entries. 889 ResourceSectionReader &resourceReader; 890 891 /// The map of the loaded dialects used to retrieve dialect information, such 892 /// as the dialect version. 893 const llvm::StringMap<BytecodeDialect *> &dialectsMap; 894 895 /// The set of attribute and type entries. 896 SmallVector<AttrEntry> attributes; 897 SmallVector<TypeEntry> types; 898 899 /// A location used for error emission. 900 Location fileLoc; 901 902 /// Current bytecode version being used. 903 uint64_t &bytecodeVersion; 904 905 /// Reference to the parser configuration. 
906 const ParserConfig &parserConfig; 907 }; 908 909 class DialectReader : public DialectBytecodeReader { 910 public: 911 DialectReader(AttrTypeReader &attrTypeReader, 912 StringSectionReader &stringReader, 913 ResourceSectionReader &resourceReader, 914 const llvm::StringMap<BytecodeDialect *> &dialectsMap, 915 EncodingReader &reader, uint64_t &bytecodeVersion) 916 : attrTypeReader(attrTypeReader), stringReader(stringReader), 917 resourceReader(resourceReader), dialectsMap(dialectsMap), 918 reader(reader), bytecodeVersion(bytecodeVersion) {} 919 920 InFlightDiagnostic emitError(const Twine &msg) const override { 921 return reader.emitError(msg); 922 } 923 924 FailureOr<const DialectVersion *> 925 getDialectVersion(StringRef dialectName) const override { 926 // First check if the dialect is available in the map. 927 auto dialectEntry = dialectsMap.find(dialectName); 928 if (dialectEntry == dialectsMap.end()) 929 return failure(); 930 // If the dialect was found, try to load it. This will trigger reading the 931 // bytecode version from the version buffer if it wasn't already processed. 932 // Return failure if either of those two actions could not be completed. 
933 if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) || 934 dialectEntry->getValue()->loadedVersion == nullptr) 935 return failure(); 936 return dialectEntry->getValue()->loadedVersion.get(); 937 } 938 939 MLIRContext *getContext() const override { return getLoc().getContext(); } 940 941 uint64_t getBytecodeVersion() const override { return bytecodeVersion; } 942 943 DialectReader withEncodingReader(EncodingReader &encReader) const { 944 return DialectReader(attrTypeReader, stringReader, resourceReader, 945 dialectsMap, encReader, bytecodeVersion); 946 } 947 948 Location getLoc() const { return reader.getLoc(); } 949 950 //===--------------------------------------------------------------------===// 951 // IR 952 //===--------------------------------------------------------------------===// 953 954 LogicalResult readAttribute(Attribute &result) override { 955 return attrTypeReader.parseAttribute(reader, result); 956 } 957 LogicalResult readOptionalAttribute(Attribute &result) override { 958 return attrTypeReader.parseOptionalAttribute(reader, result); 959 } 960 LogicalResult readType(Type &result) override { 961 return attrTypeReader.parseType(reader, result); 962 } 963 964 FailureOr<AsmDialectResourceHandle> readResourceHandle() override { 965 AsmDialectResourceHandle handle; 966 if (failed(resourceReader.parseResourceHandle(reader, handle))) 967 return failure(); 968 return handle; 969 } 970 971 //===--------------------------------------------------------------------===// 972 // Primitives 973 //===--------------------------------------------------------------------===// 974 975 LogicalResult readVarInt(uint64_t &result) override { 976 return reader.parseVarInt(result); 977 } 978 979 LogicalResult readSignedVarInt(int64_t &result) override { 980 uint64_t unsignedResult; 981 if (failed(reader.parseSignedVarInt(unsignedResult))) 982 return failure(); 983 result = static_cast<int64_t>(unsignedResult); 984 return success(); 985 } 986 987 
FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { 988 // Small values are encoded using a single byte. 989 if (bitWidth <= 8) { 990 uint8_t value; 991 if (failed(reader.parseByte(value))) 992 return failure(); 993 return APInt(bitWidth, value); 994 } 995 996 // Large values up to 64 bits are encoded using a single varint. 997 if (bitWidth <= 64) { 998 uint64_t value; 999 if (failed(reader.parseSignedVarInt(value))) 1000 return failure(); 1001 return APInt(bitWidth, value); 1002 } 1003 1004 // Otherwise, for really big values we encode the array of active words in 1005 // the value. 1006 uint64_t numActiveWords; 1007 if (failed(reader.parseVarInt(numActiveWords))) 1008 return failure(); 1009 SmallVector<uint64_t, 4> words(numActiveWords); 1010 for (uint64_t i = 0; i < numActiveWords; ++i) 1011 if (failed(reader.parseSignedVarInt(words[i]))) 1012 return failure(); 1013 return APInt(bitWidth, words); 1014 } 1015 1016 FailureOr<APFloat> 1017 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { 1018 FailureOr<APInt> intVal = 1019 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); 1020 if (failed(intVal)) 1021 return failure(); 1022 return APFloat(semantics, *intVal); 1023 } 1024 1025 LogicalResult readString(StringRef &result) override { 1026 return stringReader.parseString(reader, result); 1027 } 1028 1029 LogicalResult readBlob(ArrayRef<char> &result) override { 1030 uint64_t dataSize; 1031 ArrayRef<uint8_t> data; 1032 if (failed(reader.parseVarInt(dataSize)) || 1033 failed(reader.parseBytes(dataSize, data))) 1034 return failure(); 1035 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), 1036 data.size()); 1037 return success(); 1038 } 1039 1040 LogicalResult readBool(bool &result) override { 1041 return reader.parseByte(result); 1042 } 1043 1044 private: 1045 AttrTypeReader &attrTypeReader; 1046 StringSectionReader &stringReader; 1047 ResourceSectionReader &resourceReader; 1048 const 
llvm::StringMap<BytecodeDialect *> &dialectsMap; 1049 EncodingReader &reader; 1050 uint64_t &bytecodeVersion; 1051 }; 1052 1053 /// Wraps the properties section and handles reading properties out of it. 1054 class PropertiesSectionReader { 1055 public: 1056 /// Initialize the properties section reader with the given section data. 1057 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { 1058 if (sectionData.empty()) 1059 return success(); 1060 EncodingReader propReader(sectionData, fileLoc); 1061 uint64_t count; 1062 if (failed(propReader.parseVarInt(count))) 1063 return failure(); 1064 // Parse the raw properties buffer. 1065 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) 1066 return failure(); 1067 1068 EncodingReader offsetsReader(propertiesBuffers, fileLoc); 1069 offsetTable.reserve(count); 1070 for (auto idx : llvm::seq<int64_t>(0, count)) { 1071 (void)idx; 1072 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); 1073 ArrayRef<uint8_t> rawProperties; 1074 uint64_t dataSize; 1075 if (failed(offsetsReader.parseVarInt(dataSize)) || 1076 failed(offsetsReader.parseBytes(dataSize, rawProperties))) 1077 return failure(); 1078 } 1079 if (!offsetsReader.empty()) 1080 return offsetsReader.emitError() 1081 << "Broken properties section: didn't exhaust the offsets table"; 1082 return success(); 1083 } 1084 1085 LogicalResult read(Location fileLoc, DialectReader &dialectReader, 1086 OperationName *opName, OperationState &opState) { 1087 uint64_t propertiesIdx; 1088 if (failed(dialectReader.readVarInt(propertiesIdx))) 1089 return failure(); 1090 if (propertiesIdx >= offsetTable.size()) 1091 return dialectReader.emitError("Properties idx out-of-bound for ") 1092 << opName->getStringRef(); 1093 size_t propertiesOffset = offsetTable[propertiesIdx]; 1094 if (propertiesIdx >= propertiesBuffers.size()) 1095 return dialectReader.emitError("Properties offset out-of-bound for ") 1096 << opName->getStringRef(); 1097 
1098 // Acquire the sub-buffer that represent the requested properties. 1099 ArrayRef<char> rawProperties; 1100 { 1101 // "Seek" to the requested offset by getting a new reader with the right 1102 // sub-buffer. 1103 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), 1104 fileLoc); 1105 // Properties are stored as a sequence of {size + raw_data}. 1106 if (failed( 1107 dialectReader.withEncodingReader(reader).readBlob(rawProperties))) 1108 return failure(); 1109 } 1110 // Setup a new reader to read from the `rawProperties` sub-buffer. 1111 EncodingReader reader( 1112 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); 1113 DialectReader propReader = dialectReader.withEncodingReader(reader); 1114 1115 auto *iface = opName->getInterface<BytecodeOpInterface>(); 1116 if (iface) 1117 return iface->readProperties(propReader, opState); 1118 if (opName->isRegistered()) 1119 return propReader.emitError( 1120 "has properties but missing BytecodeOpInterface for ") 1121 << opName->getStringRef(); 1122 // Unregistered op are storing properties as an attribute. 1123 return propReader.readAttribute(opState.propertiesAttr); 1124 } 1125 1126 private: 1127 /// The properties buffer referenced within the bytecode file. 1128 ArrayRef<uint8_t> propertiesBuffers; 1129 1130 /// Table of offset in the buffer above. 1131 SmallVector<int64_t> offsetTable; 1132 }; 1133 } // namespace 1134 1135 LogicalResult AttrTypeReader::initialize( 1136 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 1137 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) { 1138 EncodingReader offsetReader(offsetSectionData, fileLoc); 1139 1140 // Parse the number of attribute and type entries. 
1141 uint64_t numAttributes, numTypes; 1142 if (failed(offsetReader.parseVarInt(numAttributes)) || 1143 failed(offsetReader.parseVarInt(numTypes))) 1144 return failure(); 1145 attributes.resize(numAttributes); 1146 types.resize(numTypes); 1147 1148 // A functor used to accumulate the offsets for the entries in the given 1149 // range. 1150 uint64_t currentOffset = 0; 1151 auto parseEntries = [&](auto &&range) { 1152 size_t currentIndex = 0, endIndex = range.size(); 1153 1154 // Parse an individual entry. 1155 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { 1156 auto &entry = range[currentIndex++]; 1157 1158 uint64_t entrySize; 1159 if (failed(offsetReader.parseVarIntWithFlag(entrySize, 1160 entry.hasCustomEncoding))) 1161 return failure(); 1162 1163 // Verify that the offset is actually valid. 1164 if (currentOffset + entrySize > sectionData.size()) { 1165 return offsetReader.emitError( 1166 "Attribute or Type entry offset points past the end of section"); 1167 } 1168 1169 entry.data = sectionData.slice(currentOffset, entrySize); 1170 entry.dialect = dialect; 1171 currentOffset += entrySize; 1172 return success(); 1173 }; 1174 while (currentIndex != endIndex) 1175 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) 1176 return failure(); 1177 return success(); 1178 }; 1179 1180 // Process each of the attributes, and then the types. 1181 if (failed(parseEntries(attributes)) || failed(parseEntries(types))) 1182 return failure(); 1183 1184 // Ensure that we read everything from the section. 
1185 if (!offsetReader.empty()) { 1186 return offsetReader.emitError( 1187 "unexpected trailing data in the Attribute/Type offset section"); 1188 } 1189 1190 return success(); 1191 } 1192 1193 template <typename T> 1194 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 1195 StringRef entryType) { 1196 if (index >= entries.size()) { 1197 emitError(fileLoc) << "invalid " << entryType << " index: " << index; 1198 return {}; 1199 } 1200 1201 // If the entry has already been resolved, there is nothing left to do. 1202 Entry<T> &entry = entries[index]; 1203 if (entry.entry) 1204 return entry.entry; 1205 1206 // Parse the entry. 1207 EncodingReader reader(entry.data, fileLoc); 1208 1209 // Parse based on how the entry was encoded. 1210 if (entry.hasCustomEncoding) { 1211 if (failed(parseCustomEntry(entry, reader, entryType))) 1212 return T(); 1213 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { 1214 return T(); 1215 } 1216 1217 if (!reader.empty()) { 1218 reader.emitError("unexpected trailing bytes after " + entryType + " entry"); 1219 return T(); 1220 } 1221 return entry.entry; 1222 } 1223 1224 template <typename T> 1225 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, 1226 StringRef entryType) { 1227 StringRef asmStr; 1228 if (failed(reader.parseNullTerminatedString(asmStr))) 1229 return failure(); 1230 1231 // Invoke the MLIR assembly parser to parse the entry text. 1232 size_t numRead = 0; 1233 MLIRContext *context = fileLoc->getContext(); 1234 if constexpr (std::is_same_v<T, Type>) 1235 result = 1236 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); 1237 else 1238 result = ::parseAttribute(asmStr, context, Type(), &numRead, 1239 /*isKnownNullTerminated=*/true); 1240 if (!result) 1241 return failure(); 1242 1243 // Ensure there weren't dangling characters after the entry. 
1244 if (numRead != asmStr.size()) { 1245 return reader.emitError("trailing characters found after ", entryType, 1246 " assembly format: ", asmStr.drop_front(numRead)); 1247 } 1248 return success(); 1249 } 1250 1251 template <typename T> 1252 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, 1253 EncodingReader &reader, 1254 StringRef entryType) { 1255 DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, 1256 reader, bytecodeVersion); 1257 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) 1258 return failure(); 1259 1260 if constexpr (std::is_same_v<T, Type>) { 1261 // Try parsing with callbacks first if available. 1262 for (const auto &callback : 1263 parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) { 1264 if (failed( 1265 callback->read(dialectReader, entry.dialect->name, entry.entry))) 1266 return failure(); 1267 // Early return if parsing was successful. 1268 if (!!entry.entry) 1269 return success(); 1270 1271 // Reset the reader if we failed to parse, so we can fall through the 1272 // other parsing functions. 1273 reader = EncodingReader(entry.data, reader.getLoc()); 1274 } 1275 } else { 1276 // Try parsing with callbacks first if available. 1277 for (const auto &callback : 1278 parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) { 1279 if (failed( 1280 callback->read(dialectReader, entry.dialect->name, entry.entry))) 1281 return failure(); 1282 // Early return if parsing was successful. 1283 if (!!entry.entry) 1284 return success(); 1285 1286 // Reset the reader if we failed to parse, so we can fall through the 1287 // other parsing functions. 1288 reader = EncodingReader(entry.data, reader.getLoc()); 1289 } 1290 } 1291 1292 // Ensure that the dialect implements the bytecode interface. 
1293 if (!entry.dialect->interface) { 1294 return reader.emitError("dialect '", entry.dialect->name, 1295 "' does not implement the bytecode interface"); 1296 } 1297 1298 if constexpr (std::is_same_v<T, Type>) 1299 entry.entry = entry.dialect->interface->readType(dialectReader); 1300 else 1301 entry.entry = entry.dialect->interface->readAttribute(dialectReader); 1302 1303 return success(!!entry.entry); 1304 } 1305 1306 //===----------------------------------------------------------------------===// 1307 // Bytecode Reader 1308 //===----------------------------------------------------------------------===// 1309 1310 /// This class is used to read a bytecode buffer and translate it into MLIR. 1311 class mlir::BytecodeReader::Impl { 1312 struct RegionReadState; 1313 using LazyLoadableOpsInfo = 1314 std::list<std::pair<Operation *, RegionReadState>>; 1315 using LazyLoadableOpsMap = 1316 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; 1317 1318 public: 1319 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, 1320 llvm::MemoryBufferRef buffer, 1321 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 1322 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), 1323 attrTypeReader(stringReader, resourceReader, dialectsMap, version, 1324 fileLoc, config), 1325 // Use the builtin unrealized conversion cast operation to represent 1326 // forward references to values that aren't yet defined. 1327 forwardRefOpState(UnknownLoc::get(config.getContext()), 1328 "builtin.unrealized_conversion_cast", ValueRange(), 1329 NoneType::get(config.getContext())), 1330 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} 1331 1332 /// Read the bytecode defined within `buffer` into the given block. 1333 LogicalResult read(Block *block, 1334 llvm::function_ref<bool(Operation *)> lazyOps); 1335 1336 /// Return the number of ops that haven't been materialized yet. 
1337 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } 1338 1339 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } 1340 1341 /// Materialize the provided operation, invoke the lazyOpsCallback on every 1342 /// newly found lazy operation. 1343 LogicalResult 1344 materialize(Operation *op, 1345 llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1346 this->lazyOpsCallback = lazyOpsCallback; 1347 auto resetlazyOpsCallback = 1348 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1349 auto it = lazyLoadableOpsMap.find(op); 1350 assert(it != lazyLoadableOpsMap.end() && 1351 "materialize called on non-materializable op"); 1352 return materialize(it); 1353 } 1354 1355 /// Materialize all operations. 1356 LogicalResult materializeAll() { 1357 while (!lazyLoadableOpsMap.empty()) { 1358 if (failed(materialize(lazyLoadableOpsMap.begin()))) 1359 return failure(); 1360 } 1361 return success(); 1362 } 1363 1364 /// Finalize the lazy-loading by calling back with every op that hasn't been 1365 /// materialized to let the client decide if the op should be deleted or 1366 /// materialized. The op is materialized if the callback returns true, deleted 1367 /// otherwise. 
1368 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { 1369 while (!lazyLoadableOps.empty()) { 1370 Operation *op = lazyLoadableOps.begin()->first; 1371 if (shouldMaterialize(op)) { 1372 if (failed(materialize(lazyLoadableOpsMap.find(op)))) 1373 return failure(); 1374 continue; 1375 } 1376 op->dropAllReferences(); 1377 op->erase(); 1378 lazyLoadableOps.pop_front(); 1379 lazyLoadableOpsMap.erase(op); 1380 } 1381 return success(); 1382 } 1383 1384 private: 1385 LogicalResult materialize(LazyLoadableOpsMap::iterator it) { 1386 assert(it != lazyLoadableOpsMap.end() && 1387 "materialize called on non-materializable op"); 1388 valueScopes.emplace_back(); 1389 std::vector<RegionReadState> regionStack; 1390 regionStack.push_back(std::move(it->getSecond()->second)); 1391 lazyLoadableOps.erase(it->getSecond()); 1392 lazyLoadableOpsMap.erase(it); 1393 1394 while (!regionStack.empty()) 1395 if (failed(parseRegions(regionStack, regionStack.back()))) 1396 return failure(); 1397 return success(); 1398 } 1399 1400 /// Return the context for this config. 1401 MLIRContext *getContext() const { return config.getContext(); } 1402 1403 /// Parse the bytecode version. 1404 LogicalResult parseVersion(EncodingReader &reader); 1405 1406 //===--------------------------------------------------------------------===// 1407 // Dialect Section 1408 1409 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); 1410 1411 /// Parse an operation name reference using the given reader, and set the 1412 /// `wasRegistered` flag that indicates if the bytecode was produced by a 1413 /// context where opName was registered. 1414 FailureOr<OperationName> parseOpName(EncodingReader &reader, 1415 std::optional<bool> &wasRegistered); 1416 1417 //===--------------------------------------------------------------------===// 1418 // Attribute/Type Section 1419 1420 /// Parse an attribute or type using the given reader. 
1421 template <typename T> 1422 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 1423 return attrTypeReader.parseAttribute(reader, result); 1424 } 1425 LogicalResult parseType(EncodingReader &reader, Type &result) { 1426 return attrTypeReader.parseType(reader, result); 1427 } 1428 1429 //===--------------------------------------------------------------------===// 1430 // Resource Section 1431 1432 LogicalResult 1433 parseResourceSection(EncodingReader &reader, 1434 std::optional<ArrayRef<uint8_t>> resourceData, 1435 std::optional<ArrayRef<uint8_t>> resourceOffsetData); 1436 1437 //===--------------------------------------------------------------------===// 1438 // IR Section 1439 1440 /// This struct represents the current read state of a range of regions. This 1441 /// struct is used to enable iterative parsing of regions. 1442 struct RegionReadState { 1443 RegionReadState(Operation *op, EncodingReader *reader, 1444 bool isIsolatedFromAbove) 1445 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} 1446 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, 1447 bool isIsolatedFromAbove) 1448 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), 1449 isIsolatedFromAbove(isIsolatedFromAbove) {} 1450 1451 /// The current regions being read. 1452 MutableArrayRef<Region>::iterator curRegion, endRegion; 1453 /// This is the reader to use for this region, this pointer is pointing to 1454 /// the parent region reader unless the current region is IsolatedFromAbove, 1455 /// in which case the pointer is pointing to the `owningReader` which is a 1456 /// section dedicated to the current region. 1457 EncodingReader *reader; 1458 std::unique_ptr<EncodingReader> owningReader; 1459 1460 /// The number of values defined immediately within this region. 1461 unsigned numValues = 0; 1462 1463 /// The current blocks of the region being read. 
1464 SmallVector<Block *> curBlocks; 1465 Region::iterator curBlock = {}; 1466 1467 /// The number of operations remaining to be read from the current block 1468 /// being read. 1469 uint64_t numOpsRemaining = 0; 1470 1471 /// A flag indicating if the regions being read are isolated from above. 1472 bool isIsolatedFromAbove = false; 1473 }; 1474 1475 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); 1476 LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, 1477 RegionReadState &readState); 1478 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, 1479 RegionReadState &readState, 1480 bool &isIsolatedFromAbove); 1481 1482 LogicalResult parseRegion(RegionReadState &readState); 1483 LogicalResult parseBlockHeader(EncodingReader &reader, 1484 RegionReadState &readState); 1485 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); 1486 1487 //===--------------------------------------------------------------------===// 1488 // Value Processing 1489 1490 /// Parse an operand reference using the given reader. Returns nullptr in the 1491 /// case of failure. 1492 Value parseOperand(EncodingReader &reader); 1493 1494 /// Sequentially define the given value range. 1495 LogicalResult defineValues(EncodingReader &reader, ValueRange values); 1496 1497 /// Create a value to use for a forward reference. 1498 Value createForwardRef(); 1499 1500 //===--------------------------------------------------------------------===// 1501 // Use-list order helpers 1502 1503 /// This struct is a simple storage that contains information required to 1504 /// reorder the use-list of a value with respect to the pre-order traversal 1505 /// ordering. 
1506 struct UseListOrderStorage { 1507 UseListOrderStorage(bool isIndexPairEncoding, 1508 SmallVector<unsigned, 4> &&indices) 1509 : indices(std::move(indices)), 1510 isIndexPairEncoding(isIndexPairEncoding){}; 1511 /// The vector containing the information required to reorder the 1512 /// use-list of a value. 1513 SmallVector<unsigned, 4> indices; 1514 1515 /// Whether indices represent a pair of type `(src, dst)` or it is a direct 1516 /// indexing, such as `dst = order[src]`. 1517 bool isIndexPairEncoding; 1518 }; 1519 1520 /// Parse use-list order from bytecode for a range of values if available. The 1521 /// range is expected to be either a block argument or an op result range. On 1522 /// success, return a map of the position in the range and the use-list order 1523 /// encoding. The function assumes to know the size of the range it is 1524 /// processing. 1525 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; 1526 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, 1527 uint64_t rangeSize); 1528 1529 /// Shuffle the use-chain according to the order parsed. 1530 LogicalResult sortUseListOrder(Value value); 1531 1532 /// Recursively visit all the values defined within topLevelOp and sort the 1533 /// use-list orders according to the indices parsed. 1534 LogicalResult processUseLists(Operation *topLevelOp); 1535 1536 //===--------------------------------------------------------------------===// 1537 // Fields 1538 1539 /// This class represents a single value scope, in which a value scope is 1540 /// delimited by isolated from above regions. 1541 struct ValueScope { 1542 /// Push a new region state onto this scope, reserving enough values for 1543 /// those defined within the current region of the provided state. 
1544 void push(RegionReadState &readState) { 1545 nextValueIDs.push_back(values.size()); 1546 values.resize(values.size() + readState.numValues); 1547 } 1548 1549 /// Pop the values defined for the current region within the provided region 1550 /// state. 1551 void pop(RegionReadState &readState) { 1552 values.resize(values.size() - readState.numValues); 1553 nextValueIDs.pop_back(); 1554 } 1555 1556 /// The set of values defined in this scope. 1557 std::vector<Value> values; 1558 1559 /// The ID for the next defined value for each region current being 1560 /// processed in this scope. 1561 SmallVector<unsigned, 4> nextValueIDs; 1562 }; 1563 1564 /// The configuration of the parser. 1565 const ParserConfig &config; 1566 1567 /// A location to use when emitting errors. 1568 Location fileLoc; 1569 1570 /// Flag that indicates if lazyloading is enabled. 1571 bool lazyLoading; 1572 1573 /// Keep track of operations that have been lazy loaded (their regions haven't 1574 /// been materialized), along with the `RegionReadState` that allows to 1575 /// lazy-load the regions nested under the operation. 1576 LazyLoadableOpsInfo lazyLoadableOps; 1577 LazyLoadableOpsMap lazyLoadableOpsMap; 1578 llvm::function_ref<bool(Operation *)> lazyOpsCallback; 1579 1580 /// The reader used to process attribute and types within the bytecode. 1581 AttrTypeReader attrTypeReader; 1582 1583 /// The version of the bytecode being read. 1584 uint64_t version = 0; 1585 1586 /// The producer of the bytecode being read. 1587 StringRef producer; 1588 1589 /// The table of IR units referenced within the bytecode file. 1590 SmallVector<std::unique_ptr<BytecodeDialect>> dialects; 1591 llvm::StringMap<BytecodeDialect *> dialectsMap; 1592 SmallVector<BytecodeOperationName> opNames; 1593 1594 /// The reader used to process resources within the bytecode. 1595 ResourceSectionReader resourceReader; 1596 1597 /// Worklist of values with custom use-list orders to process before the end 1598 /// of the parsing. 
1599 DenseMap<void *, UseListOrderStorage> valueToUseListMap; 1600 1601 /// The table of strings referenced within the bytecode file. 1602 StringSectionReader stringReader; 1603 1604 /// The table of properties referenced by the operation in the bytecode file. 1605 PropertiesSectionReader propertiesReader; 1606 1607 /// The current set of available IR value scopes. 1608 std::vector<ValueScope> valueScopes; 1609 1610 /// The global pre-order operation ordering. 1611 DenseMap<Operation *, unsigned> operationIDs; 1612 1613 /// A block containing the set of operations defined to create forward 1614 /// references. 1615 Block forwardRefOps; 1616 1617 /// A block containing previously created, and no longer used, forward 1618 /// reference operations. 1619 Block openForwardRefOps; 1620 1621 /// An operation state used when instantiating forward references. 1622 OperationState forwardRefOpState; 1623 1624 /// Reference to the input buffer. 1625 llvm::MemoryBufferRef buffer; 1626 1627 /// The optional owning source manager, which when present may be used to 1628 /// extend the lifetime of the input buffer. 1629 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 1630 }; 1631 1632 LogicalResult BytecodeReader::Impl::read( 1633 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1634 EncodingReader reader(buffer.getBuffer(), fileLoc); 1635 this->lazyOpsCallback = lazyOpsCallback; 1636 auto resetlazyOpsCallback = 1637 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1638 1639 // Skip over the bytecode header, this should have already been checked. 1640 if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) 1641 return failure(); 1642 // Parse the bytecode version and producer. 1643 if (failed(parseVersion(reader)) || 1644 failed(reader.parseNullTerminatedString(producer))) 1645 return failure(); 1646 1647 // Add a diagnostic handler that attaches a note that includes the original 1648 // producer of the bytecode. 
1649 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { 1650 diag.attachNote() << "in bytecode version " << version 1651 << " produced by: " << producer; 1652 return failure(); 1653 }); 1654 1655 // Parse the raw data for each of the top-level sections of the bytecode. 1656 std::optional<ArrayRef<uint8_t>> 1657 sectionDatas[bytecode::Section::kNumSections]; 1658 while (!reader.empty()) { 1659 // Read the next section from the bytecode. 1660 bytecode::Section::ID sectionID; 1661 ArrayRef<uint8_t> sectionData; 1662 if (failed(reader.parseSection(sectionID, sectionData))) 1663 return failure(); 1664 1665 // Check for duplicate sections, we only expect one instance of each. 1666 if (sectionDatas[sectionID]) { 1667 return reader.emitError("duplicate top-level section: ", 1668 ::toString(sectionID)); 1669 } 1670 sectionDatas[sectionID] = sectionData; 1671 } 1672 // Check that all of the required sections were found. 1673 for (int i = 0; i < bytecode::Section::kNumSections; ++i) { 1674 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); 1675 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { 1676 return reader.emitError("missing data for top-level section: ", 1677 ::toString(sectionID)); 1678 } 1679 } 1680 1681 // Process the string section first. 1682 if (failed(stringReader.initialize( 1683 fileLoc, *sectionDatas[bytecode::Section::kString]))) 1684 return failure(); 1685 1686 // Process the properties section. 1687 if (sectionDatas[bytecode::Section::kProperties] && 1688 failed(propertiesReader.initialize( 1689 fileLoc, *sectionDatas[bytecode::Section::kProperties]))) 1690 return failure(); 1691 1692 // Process the dialect section. 1693 if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) 1694 return failure(); 1695 1696 // Process the resource section if present. 
1697 if (failed(parseResourceSection( 1698 reader, sectionDatas[bytecode::Section::kResource], 1699 sectionDatas[bytecode::Section::kResourceOffset]))) 1700 return failure(); 1701 1702 // Process the attribute and type section. 1703 if (failed(attrTypeReader.initialize( 1704 dialects, *sectionDatas[bytecode::Section::kAttrType], 1705 *sectionDatas[bytecode::Section::kAttrTypeOffset]))) 1706 return failure(); 1707 1708 // Finally, process the IR section. 1709 return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); 1710 } 1711 1712 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { 1713 if (failed(reader.parseVarInt(version))) 1714 return failure(); 1715 1716 // Validate the bytecode version. 1717 uint64_t currentVersion = bytecode::kVersion; 1718 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; 1719 if (version < minSupportedVersion) { 1720 return reader.emitError("bytecode version ", version, 1721 " is older than the current version of ", 1722 currentVersion, ", and upgrade is not supported"); 1723 } 1724 if (version > currentVersion) { 1725 return reader.emitError("bytecode version ", version, 1726 " is newer than the current version ", 1727 currentVersion); 1728 } 1729 // Override any request to lazy-load if the bytecode version is too old. 1730 if (version < bytecode::kLazyLoading) 1731 lazyLoading = false; 1732 return success(); 1733 } 1734 1735 //===----------------------------------------------------------------------===// 1736 // Dialect Section 1737 1738 LogicalResult BytecodeDialect::load(const DialectReader &reader, 1739 MLIRContext *ctx) { 1740 if (dialect) 1741 return success(); 1742 Dialect *loadedDialect = ctx->getOrLoadDialect(name); 1743 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { 1744 return reader.emitError("dialect '") 1745 << name 1746 << "' is unknown. 
If this is intended, please call " 1747 "allowUnregisteredDialects() on the MLIRContext, or use " 1748 "-allow-unregistered-dialect with the MLIR tool used."; 1749 } 1750 dialect = loadedDialect; 1751 1752 // If the dialect was actually loaded, check to see if it has a bytecode 1753 // interface. 1754 if (loadedDialect) 1755 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); 1756 if (!versionBuffer.empty()) { 1757 if (!interface) 1758 return reader.emitError("dialect '") 1759 << name 1760 << "' does not implement the bytecode interface, " 1761 "but found a version entry"; 1762 EncodingReader encReader(versionBuffer, reader.getLoc()); 1763 DialectReader versionReader = reader.withEncodingReader(encReader); 1764 loadedVersion = interface->readVersion(versionReader); 1765 if (!loadedVersion) 1766 return failure(); 1767 } 1768 return success(); 1769 } 1770 1771 LogicalResult 1772 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { 1773 EncodingReader sectionReader(sectionData, fileLoc); 1774 1775 // Parse the number of dialects in the section. 1776 uint64_t numDialects; 1777 if (failed(sectionReader.parseVarInt(numDialects))) 1778 return failure(); 1779 dialects.resize(numDialects); 1780 1781 // Parse each of the dialects. 1782 for (uint64_t i = 0; i < numDialects; ++i) { 1783 dialects[i] = std::make_unique<BytecodeDialect>(); 1784 /// Before version kDialectVersioning, there wasn't any versioning available 1785 /// for dialects, and the entryIdx represent the string itself. 1786 if (version < bytecode::kDialectVersioning) { 1787 if (failed(stringReader.parseString(sectionReader, dialects[i]->name))) 1788 return failure(); 1789 continue; 1790 } 1791 1792 // Parse ID representing dialect and version. 
1793 uint64_t dialectNameIdx; 1794 bool versionAvailable; 1795 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, 1796 versionAvailable))) 1797 return failure(); 1798 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, 1799 dialects[i]->name))) 1800 return failure(); 1801 if (versionAvailable) { 1802 bytecode::Section::ID sectionID; 1803 if (failed(sectionReader.parseSection(sectionID, 1804 dialects[i]->versionBuffer))) 1805 return failure(); 1806 if (sectionID != bytecode::Section::kDialectVersions) { 1807 emitError(fileLoc, "expected dialect version section"); 1808 return failure(); 1809 } 1810 } 1811 dialectsMap[dialects[i]->name] = dialects[i].get(); 1812 } 1813 1814 // Parse the operation names, which are grouped by dialect. 1815 auto parseOpName = [&](BytecodeDialect *dialect) { 1816 StringRef opName; 1817 std::optional<bool> wasRegistered; 1818 // Prior to version kNativePropertiesEncoding, the information about wheter 1819 // an op was registered or not wasn't encoded. 1820 if (version < bytecode::kNativePropertiesEncoding) { 1821 if (failed(stringReader.parseString(sectionReader, opName))) 1822 return failure(); 1823 } else { 1824 bool wasRegisteredFlag; 1825 if (failed(stringReader.parseStringWithFlag(sectionReader, opName, 1826 wasRegisteredFlag))) 1827 return failure(); 1828 wasRegistered = wasRegisteredFlag; 1829 } 1830 opNames.emplace_back(dialect, opName, wasRegistered); 1831 return success(); 1832 }; 1833 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation 1834 // where the number of ops are known. 
1835 if (version >= bytecode::kElideUnknownBlockArgLocation) { 1836 uint64_t numOps; 1837 if (failed(sectionReader.parseVarInt(numOps))) 1838 return failure(); 1839 opNames.reserve(numOps); 1840 } 1841 while (!sectionReader.empty()) 1842 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) 1843 return failure(); 1844 return success(); 1845 } 1846 1847 FailureOr<OperationName> 1848 BytecodeReader::Impl::parseOpName(EncodingReader &reader, 1849 std::optional<bool> &wasRegistered) { 1850 BytecodeOperationName *opName = nullptr; 1851 if (failed(parseEntry(reader, opNames, opName, "operation name"))) 1852 return failure(); 1853 wasRegistered = opName->wasRegistered; 1854 // Check to see if this operation name has already been resolved. If we 1855 // haven't, load the dialect and build the operation name. 1856 if (!opName->opName) { 1857 // Load the dialect and its version. 1858 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1859 dialectsMap, reader, version); 1860 if (failed(opName->dialect->load(dialectReader, getContext()))) 1861 return failure(); 1862 // If the opName is empty, this is because we use to accept names such as 1863 // `foo` without any `.` separator. We shouldn't tolerate this in textual 1864 // format anymore but for now we'll be backward compatible. This can only 1865 // happen with unregistered dialects. 1866 if (opName->name.empty()) { 1867 if (opName->dialect->getLoadedDialect()) 1868 return emitError(fileLoc) << "has an empty opname for dialect '" 1869 << opName->dialect->name << "'\n"; 1870 1871 opName->opName.emplace(opName->dialect->name, getContext()); 1872 } else { 1873 opName->opName.emplace((opName->dialect->name + "." 
+ opName->name).str(), 1874 getContext()); 1875 } 1876 } 1877 return *opName->opName; 1878 } 1879 1880 //===----------------------------------------------------------------------===// 1881 // Resource Section 1882 1883 LogicalResult BytecodeReader::Impl::parseResourceSection( 1884 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, 1885 std::optional<ArrayRef<uint8_t>> resourceOffsetData) { 1886 // Ensure both sections are either present or not. 1887 if (resourceData.has_value() != resourceOffsetData.has_value()) { 1888 if (resourceOffsetData) 1889 return emitError(fileLoc, "unexpected resource offset section when " 1890 "resource section is not present"); 1891 return emitError( 1892 fileLoc, 1893 "expected resource offset section when resource section is present"); 1894 } 1895 1896 // If the resource sections are absent, there is nothing to do. 1897 if (!resourceData) 1898 return success(); 1899 1900 // Initialize the resource reader with the resource sections. 1901 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1902 dialectsMap, reader, version); 1903 return resourceReader.initialize(fileLoc, config, dialects, stringReader, 1904 *resourceData, *resourceOffsetData, 1905 dialectReader, bufferOwnerRef); 1906 } 1907 1908 //===----------------------------------------------------------------------===// 1909 // UseListOrder Helpers 1910 1911 FailureOr<BytecodeReader::Impl::UseListMapT> 1912 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, 1913 uint64_t numResults) { 1914 BytecodeReader::Impl::UseListMapT map; 1915 uint64_t numValuesToRead = 1; 1916 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) 1917 return failure(); 1918 1919 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { 1920 uint64_t resultIdx = 0; 1921 if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) 1922 return failure(); 1923 1924 uint64_t numValues; 1925 bool indexPairEncoding; 1926 if 
(failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) 1927 return failure(); 1928 1929 SmallVector<unsigned, 4> useListOrders; 1930 for (size_t idx = 0; idx < numValues; idx++) { 1931 uint64_t index; 1932 if (failed(reader.parseVarInt(index))) 1933 return failure(); 1934 useListOrders.push_back(index); 1935 } 1936 1937 // Store in a map the result index 1938 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, 1939 std::move(useListOrders))); 1940 } 1941 1942 return map; 1943 } 1944 1945 /// Sorts each use according to the order specified in the use-list parsed. If 1946 /// the custom use-list is not found, this means that the order needs to be 1947 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie 1948 /// on the same operation, the order will follow the reverse operand number 1949 /// ordering. 1950 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { 1951 // Early return for trivial use-lists. 1952 if (value.use_empty() || value.hasOneUse()) 1953 return success(); 1954 1955 bool hasIncomingOrder = 1956 valueToUseListMap.contains(value.getAsOpaquePointer()); 1957 1958 // Compute the current order of the use-list with respect to the global 1959 // ordering. Detect if the order is already sorted while doing so. 1960 bool alreadySorted = true; 1961 auto &firstUse = *value.use_begin(); 1962 uint64_t prevID = 1963 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); 1964 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; 1965 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1966 uint64_t currentID = bytecode::getUseID( 1967 item.value(), operationIDs.at(item.value().getOwner())); 1968 alreadySorted &= prevID > currentID; 1969 currentOrder.push_back({item.index(), currentID}); 1970 prevID = currentID; 1971 } 1972 1973 // If the order is already sorted, and there wasn't a custom order to apply 1974 // from the bytecode file, we are done. 
1975 if (alreadySorted && !hasIncomingOrder) 1976 return success(); 1977 1978 // If not already sorted, sort the indices of the current order by descending 1979 // useIDs. 1980 if (!alreadySorted) 1981 std::sort( 1982 currentOrder.begin(), currentOrder.end(), 1983 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1984 1985 if (!hasIncomingOrder) { 1986 // If the bytecode file did not contain any custom use-list order, it means 1987 // that the order was descending useID. Hence, shuffle by the first index 1988 // of the `currentOrder` pair. 1989 SmallVector<unsigned> shuffle = SmallVector<unsigned>( 1990 llvm::map_range(currentOrder, [&](auto item) { return item.first; })); 1991 value.shuffleUseList(shuffle); 1992 return success(); 1993 } 1994 1995 // Pull the custom order info from the map. 1996 UseListOrderStorage customOrder = 1997 valueToUseListMap.at(value.getAsOpaquePointer()); 1998 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); 1999 uint64_t numUses = 2000 std::distance(value.getUses().begin(), value.getUses().end()); 2001 2002 // If the encoding was a pair of indices `(src, dst)` for every permutation, 2003 // reconstruct the shuffle vector for every use. Initialize the shuffle vector 2004 // as identity, and then apply the mapping encoded in the indices. 2005 if (customOrder.isIndexPairEncoding) { 2006 // Return failure if the number of indices was not representing pairs. 2007 if (shuffle.size() & 1) 2008 return failure(); 2009 2010 SmallVector<unsigned, 4> newShuffle(numUses); 2011 size_t idx = 0; 2012 std::iota(newShuffle.begin(), newShuffle.end(), idx); 2013 for (idx = 0; idx < shuffle.size(); idx += 2) 2014 newShuffle[shuffle[idx]] = shuffle[idx + 1]; 2015 2016 shuffle = std::move(newShuffle); 2017 } 2018 2019 // Make sure that the indices represent a valid mapping. That is, the sum of 2020 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no 2021 // duplicates are allowed in the list. 
2022 DenseSet<unsigned> set; 2023 uint64_t accumulator = 0; 2024 for (const auto &elem : shuffle) { 2025 if (set.contains(elem)) 2026 return failure(); 2027 accumulator += elem; 2028 set.insert(elem); 2029 } 2030 if (numUses != shuffle.size() || 2031 accumulator != (((numUses - 1) * numUses) >> 1)) 2032 return failure(); 2033 2034 // Apply the current ordering map onto the shuffle vector to get the final 2035 // use-list sorting indices before shuffling. 2036 shuffle = SmallVector<unsigned, 4>(llvm::map_range( 2037 currentOrder, [&](auto item) { return shuffle[item.first]; })); 2038 value.shuffleUseList(shuffle); 2039 return success(); 2040 } 2041 2042 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { 2043 // Precompute operation IDs according to the pre-order walk of the IR. We 2044 // can't do this while parsing since parseRegions ordering is not strictly 2045 // equal to the pre-order walk. 2046 unsigned operationID = 0; 2047 topLevelOp->walk<mlir::WalkOrder::PreOrder>( 2048 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); 2049 2050 auto blockWalk = topLevelOp->walk([this](Block *block) { 2051 for (auto arg : block->getArguments()) 2052 if (failed(sortUseListOrder(arg))) 2053 return WalkResult::interrupt(); 2054 return WalkResult::advance(); 2055 }); 2056 2057 auto resultWalk = topLevelOp->walk([this](Operation *op) { 2058 for (auto result : op->getResults()) 2059 if (failed(sortUseListOrder(result))) 2060 return WalkResult::interrupt(); 2061 return WalkResult::advance(); 2062 }); 2063 2064 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); 2065 } 2066 2067 //===----------------------------------------------------------------------===// 2068 // IR Section 2069 2070 LogicalResult 2071 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, 2072 Block *block) { 2073 EncodingReader reader(sectionData, fileLoc); 2074 2075 // A stack of operation regions currently being read from the 
bytecode. 2076 std::vector<RegionReadState> regionStack; 2077 2078 // Parse the top-level block using a temporary module operation. 2079 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); 2080 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); 2081 regionStack.back().curBlocks.push_back(moduleOp->getBody()); 2082 regionStack.back().curBlock = regionStack.back().curRegion->begin(); 2083 if (failed(parseBlockHeader(reader, regionStack.back()))) 2084 return failure(); 2085 valueScopes.emplace_back(); 2086 valueScopes.back().push(regionStack.back()); 2087 2088 // Iteratively parse regions until everything has been resolved. 2089 while (!regionStack.empty()) 2090 if (failed(parseRegions(regionStack, regionStack.back()))) 2091 return failure(); 2092 if (!forwardRefOps.empty()) { 2093 return reader.emitError( 2094 "not all forward unresolved forward operand references"); 2095 } 2096 2097 // Sort use-lists according to what specified in bytecode. 2098 if (failed(processUseLists(*moduleOp))) 2099 return reader.emitError( 2100 "parsed use-list orders were invalid and could not be applied"); 2101 2102 // Resolve dialect version. 2103 for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) { 2104 // Parsing is complete, give an opportunity to each dialect to visit the 2105 // IR and perform upgrades. 2106 if (!byteCodeDialect->loadedVersion) 2107 continue; 2108 if (byteCodeDialect->interface && 2109 failed(byteCodeDialect->interface->upgradeFromVersion( 2110 *moduleOp, *byteCodeDialect->loadedVersion))) 2111 return failure(); 2112 } 2113 2114 // Verify that the parsed operations are valid. 2115 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) 2116 return failure(); 2117 2118 // Splice the parsed operations over to the provided top-level block. 
2119 auto &parsedOps = moduleOp->getBody()->getOperations(); 2120 auto &destOps = block->getOperations(); 2121 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); 2122 return success(); 2123 } 2124 2125 LogicalResult 2126 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, 2127 RegionReadState &readState) { 2128 // Process regions, blocks, and operations until the end or if a nested 2129 // region is encountered. In this case we push a new state in regionStack and 2130 // return, the processing of the current region will resume afterward. 2131 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { 2132 // If the current block hasn't been setup yet, parse the header for this 2133 // region. The current block is already setup when this function was 2134 // interrupted to recurse down in a nested region and we resume the current 2135 // block after processing the nested region. 2136 if (readState.curBlock == Region::iterator()) { 2137 if (failed(parseRegion(readState))) 2138 return failure(); 2139 2140 // If the region is empty, there is nothing to more to do. 2141 if (readState.curRegion->empty()) 2142 continue; 2143 } 2144 2145 // Parse the blocks within the region. 2146 EncodingReader &reader = *readState.reader; 2147 do { 2148 while (readState.numOpsRemaining--) { 2149 // Read in the next operation. We don't read its regions directly, we 2150 // handle those afterwards as necessary. 2151 bool isIsolatedFromAbove = false; 2152 FailureOr<Operation *> op = 2153 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); 2154 if (failed(op)) 2155 return failure(); 2156 2157 // If the op has regions, add it to the stack for processing and return: 2158 // we stop the processing of the current region and resume it after the 2159 // inner one is completed. Unless LazyLoading is activated in which case 2160 // nested region parsing is delayed. 
2161 if ((*op)->getNumRegions()) { 2162 RegionReadState childState(*op, &reader, isIsolatedFromAbove); 2163 2164 // Isolated regions are encoded as a section in version 2 and above. 2165 if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { 2166 bytecode::Section::ID sectionID; 2167 ArrayRef<uint8_t> sectionData; 2168 if (failed(reader.parseSection(sectionID, sectionData))) 2169 return failure(); 2170 if (sectionID != bytecode::Section::kIR) 2171 return emitError(fileLoc, "expected IR section for region"); 2172 childState.owningReader = 2173 std::make_unique<EncodingReader>(sectionData, fileLoc); 2174 childState.reader = childState.owningReader.get(); 2175 2176 // If the user has a callback set, they have the opportunity to 2177 // control lazyloading as we go. 2178 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) { 2179 lazyLoadableOps.emplace_back(*op, std::move(childState)); 2180 lazyLoadableOpsMap.try_emplace(*op, 2181 std::prev(lazyLoadableOps.end())); 2182 continue; 2183 } 2184 } 2185 regionStack.push_back(std::move(childState)); 2186 2187 // If the op is isolated from above, push a new value scope. 2188 if (isIsolatedFromAbove) 2189 valueScopes.emplace_back(); 2190 return success(); 2191 } 2192 } 2193 2194 // Move to the next block of the region. 2195 if (++readState.curBlock == readState.curRegion->end()) 2196 break; 2197 if (failed(parseBlockHeader(reader, readState))) 2198 return failure(); 2199 } while (true); 2200 2201 // Reset the current block and any values reserved for this region. 2202 readState.curBlock = {}; 2203 valueScopes.back().pop(readState); 2204 } 2205 2206 // When the regions have been fully parsed, pop them off of the read stack. If 2207 // the regions were isolated from above, we also pop the last value scope. 
2208 if (readState.isIsolatedFromAbove) { 2209 assert(!valueScopes.empty() && "Expect a valueScope after reading region"); 2210 valueScopes.pop_back(); 2211 } 2212 assert(!regionStack.empty() && "Expect a regionStack after reading region"); 2213 regionStack.pop_back(); 2214 return success(); 2215 } 2216 2217 FailureOr<Operation *> 2218 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, 2219 RegionReadState &readState, 2220 bool &isIsolatedFromAbove) { 2221 // Parse the name of the operation. 2222 std::optional<bool> wasRegistered; 2223 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); 2224 if (failed(opName)) 2225 return failure(); 2226 2227 // Parse the operation mask, which indicates which components of the operation 2228 // are present. 2229 uint8_t opMask; 2230 if (failed(reader.parseByte(opMask))) 2231 return failure(); 2232 2233 /// Parse the location. 2234 LocationAttr opLoc; 2235 if (failed(parseAttribute(reader, opLoc))) 2236 return failure(); 2237 2238 // With the location and name resolved, we can start building the operation 2239 // state. 2240 OperationState opState(opLoc, *opName); 2241 2242 // Parse the attributes of the operation. 2243 if (opMask & bytecode::OpEncodingMask::kHasAttrs) { 2244 DictionaryAttr dictAttr; 2245 if (failed(parseAttribute(reader, dictAttr))) 2246 return failure(); 2247 opState.attributes = dictAttr; 2248 } 2249 2250 if (opMask & bytecode::OpEncodingMask::kHasProperties) { 2251 // kHasProperties wasn't emitted in older bytecode, we should never get 2252 // there without also having the `wasRegistered` flag available. 2253 if (!wasRegistered) 2254 return emitError(fileLoc, 2255 "Unexpected missing `wasRegistered` opname flag at " 2256 "bytecode version ") 2257 << version << " with properties."; 2258 // When an operation is emitted without being registered, the properties are 2259 // stored as an attribute. 
Otherwise the op must implement the bytecode 2260 // interface and control the serialization. 2261 if (wasRegistered) { 2262 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 2263 dialectsMap, reader, version); 2264 if (failed( 2265 propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) 2266 return failure(); 2267 } else { 2268 // If the operation wasn't registered when it was emitted, the properties 2269 // was serialized as an attribute. 2270 if (failed(parseAttribute(reader, opState.propertiesAttr))) 2271 return failure(); 2272 } 2273 } 2274 2275 /// Parse the results of the operation. 2276 if (opMask & bytecode::OpEncodingMask::kHasResults) { 2277 uint64_t numResults; 2278 if (failed(reader.parseVarInt(numResults))) 2279 return failure(); 2280 opState.types.resize(numResults); 2281 for (int i = 0, e = numResults; i < e; ++i) 2282 if (failed(parseType(reader, opState.types[i]))) 2283 return failure(); 2284 } 2285 2286 /// Parse the operands of the operation. 2287 if (opMask & bytecode::OpEncodingMask::kHasOperands) { 2288 uint64_t numOperands; 2289 if (failed(reader.parseVarInt(numOperands))) 2290 return failure(); 2291 opState.operands.resize(numOperands); 2292 for (int i = 0, e = numOperands; i < e; ++i) 2293 if (!(opState.operands[i] = parseOperand(reader))) 2294 return failure(); 2295 } 2296 2297 /// Parse the successors of the operation. 2298 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { 2299 uint64_t numSuccs; 2300 if (failed(reader.parseVarInt(numSuccs))) 2301 return failure(); 2302 opState.successors.resize(numSuccs); 2303 for (int i = 0, e = numSuccs; i < e; ++i) { 2304 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], 2305 "successor"))) 2306 return failure(); 2307 } 2308 } 2309 2310 /// Parse the use-list orders for the results of the operation. Use-list 2311 /// orders are available since version 3 of the bytecode. 
2312 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; 2313 if (version >= bytecode::kUseListOrdering && 2314 (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { 2315 size_t numResults = opState.types.size(); 2316 auto parseResult = parseUseListOrderForRange(reader, numResults); 2317 if (failed(parseResult)) 2318 return failure(); 2319 resultIdxToUseListMap = std::move(*parseResult); 2320 } 2321 2322 /// Parse the regions of the operation. 2323 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { 2324 uint64_t numRegions; 2325 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) 2326 return failure(); 2327 2328 opState.regions.reserve(numRegions); 2329 for (int i = 0, e = numRegions; i < e; ++i) 2330 opState.regions.push_back(std::make_unique<Region>()); 2331 } 2332 2333 // Create the operation at the back of the current block. 2334 Operation *op = Operation::create(opState); 2335 readState.curBlock->push_back(op); 2336 2337 // If the operation had results, update the value references. 2338 if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) 2339 return failure(); 2340 2341 /// Store a map for every value that received a custom use-list order from the 2342 /// bytecode file. 2343 if (resultIdxToUseListMap.has_value()) { 2344 for (size_t idx = 0; idx < op->getNumResults(); idx++) { 2345 if (resultIdxToUseListMap->contains(idx)) { 2346 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), 2347 resultIdxToUseListMap->at(idx)); 2348 } 2349 } 2350 } 2351 return op; 2352 } 2353 2354 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { 2355 EncodingReader &reader = *readState.reader; 2356 2357 // Parse the number of blocks in the region. 2358 uint64_t numBlocks; 2359 if (failed(reader.parseVarInt(numBlocks))) 2360 return failure(); 2361 2362 // If the region is empty, there is nothing else to do. 
2363 if (numBlocks == 0) 2364 return success(); 2365 2366 // Parse the number of values defined in this region. 2367 uint64_t numValues; 2368 if (failed(reader.parseVarInt(numValues))) 2369 return failure(); 2370 readState.numValues = numValues; 2371 2372 // Create the blocks within this region. We do this before processing so that 2373 // we can rely on the blocks existing when creating operations. 2374 readState.curBlocks.clear(); 2375 readState.curBlocks.reserve(numBlocks); 2376 for (uint64_t i = 0; i < numBlocks; ++i) { 2377 readState.curBlocks.push_back(new Block()); 2378 readState.curRegion->push_back(readState.curBlocks.back()); 2379 } 2380 2381 // Prepare the current value scope for this region. 2382 valueScopes.back().push(readState); 2383 2384 // Parse the entry block of the region. 2385 readState.curBlock = readState.curRegion->begin(); 2386 return parseBlockHeader(reader, readState); 2387 } 2388 2389 LogicalResult 2390 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, 2391 RegionReadState &readState) { 2392 bool hasArgs; 2393 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) 2394 return failure(); 2395 2396 // Parse the arguments of the block. 2397 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) 2398 return failure(); 2399 2400 // Uselist orders are available since version 3 of the bytecode. 
2401 if (version < bytecode::kUseListOrdering) 2402 return success(); 2403 2404 uint8_t hasUseListOrders = 0; 2405 if (hasArgs && failed(reader.parseByte(hasUseListOrders))) 2406 return failure(); 2407 2408 if (!hasUseListOrders) 2409 return success(); 2410 2411 Block &blk = *readState.curBlock; 2412 auto argIdxToUseListMap = 2413 parseUseListOrderForRange(reader, blk.getNumArguments()); 2414 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) 2415 return failure(); 2416 2417 for (size_t idx = 0; idx < blk.getNumArguments(); idx++) 2418 if (argIdxToUseListMap->contains(idx)) 2419 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), 2420 argIdxToUseListMap->at(idx)); 2421 2422 // We don't parse the operations of the block here, that's done elsewhere. 2423 return success(); 2424 } 2425 2426 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, 2427 Block *block) { 2428 // Parse the value ID for the first argument, and the number of arguments. 2429 uint64_t numArgs; 2430 if (failed(reader.parseVarInt(numArgs))) 2431 return failure(); 2432 2433 SmallVector<Type> argTypes; 2434 SmallVector<Location> argLocs; 2435 argTypes.reserve(numArgs); 2436 argLocs.reserve(numArgs); 2437 2438 Location unknownLoc = UnknownLoc::get(config.getContext()); 2439 while (numArgs--) { 2440 Type argType; 2441 LocationAttr argLoc = unknownLoc; 2442 if (version >= bytecode::kElideUnknownBlockArgLocation) { 2443 // Parse the type with hasLoc flag to determine if it has type. 2444 uint64_t typeIdx; 2445 bool hasLoc; 2446 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || 2447 !(argType = attrTypeReader.resolveType(typeIdx))) 2448 return failure(); 2449 if (hasLoc && failed(parseAttribute(reader, argLoc))) 2450 return failure(); 2451 } else { 2452 // All args has type and location. 
2453 if (failed(parseType(reader, argType)) || 2454 failed(parseAttribute(reader, argLoc))) 2455 return failure(); 2456 } 2457 argTypes.push_back(argType); 2458 argLocs.push_back(argLoc); 2459 } 2460 block->addArguments(argTypes, argLocs); 2461 return defineValues(reader, block->getArguments()); 2462 } 2463 2464 //===----------------------------------------------------------------------===// 2465 // Value Processing 2466 2467 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { 2468 std::vector<Value> &values = valueScopes.back().values; 2469 Value *value = nullptr; 2470 if (failed(parseEntry(reader, values, value, "value"))) 2471 return Value(); 2472 2473 // Create a new forward reference if necessary. 2474 if (!*value) 2475 *value = createForwardRef(); 2476 return *value; 2477 } 2478 2479 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, 2480 ValueRange newValues) { 2481 ValueScope &valueScope = valueScopes.back(); 2482 std::vector<Value> &values = valueScope.values; 2483 2484 unsigned &valueID = valueScope.nextValueIDs.back(); 2485 unsigned valueIDEnd = valueID + newValues.size(); 2486 if (valueIDEnd > values.size()) { 2487 return reader.emitError( 2488 "value index range was outside of the expected range for " 2489 "the parent region, got [", 2490 valueID, ", ", valueIDEnd, "), but the maximum index was ", 2491 values.size() - 1); 2492 } 2493 2494 // Assign the values and update any forward references. 2495 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { 2496 Value newValue = newValues[i]; 2497 2498 // Check to see if a definition for this value already exists. 2499 if (Value oldValue = std::exchange(values[valueID], newValue)) { 2500 Operation *forwardRefOp = oldValue.getDefiningOp(); 2501 2502 // Assert that this is a forward reference operation. Given how we compute 2503 // definition ids (incrementally as we parse), it shouldn't be possible 2504 // for the value to be defined any other way. 
2505 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && 2506 "value index was already defined?"); 2507 2508 oldValue.replaceAllUsesWith(newValue); 2509 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); 2510 } 2511 } 2512 return success(); 2513 } 2514 2515 Value BytecodeReader::Impl::createForwardRef() { 2516 // Check for an avaliable existing operation to use. Otherwise, create a new 2517 // fake operation to use for the reference. 2518 if (!openForwardRefOps.empty()) { 2519 Operation *op = &openForwardRefOps.back(); 2520 op->moveBefore(&forwardRefOps, forwardRefOps.end()); 2521 } else { 2522 forwardRefOps.push_back(Operation::create(forwardRefOpState)); 2523 } 2524 return forwardRefOps.back().getResult(0); 2525 } 2526 2527 //===----------------------------------------------------------------------===// 2528 // Entry Points 2529 //===----------------------------------------------------------------------===// 2530 2531 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } 2532 2533 BytecodeReader::BytecodeReader( 2534 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, 2535 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2536 Location sourceFileLoc = 2537 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2538 /*line=*/0, /*column=*/0); 2539 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, 2540 bufferOwnerRef); 2541 } 2542 2543 LogicalResult BytecodeReader::readTopLevel( 2544 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2545 return impl->read(block, lazyOpsCallback); 2546 } 2547 2548 int64_t BytecodeReader::getNumOpsToMaterialize() const { 2549 return impl->getNumOpsToMaterialize(); 2550 } 2551 2552 bool BytecodeReader::isMaterializable(Operation *op) { 2553 return impl->isMaterializable(op); 2554 } 2555 2556 LogicalResult BytecodeReader::materialize( 2557 Operation *op, 
llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2558 return impl->materialize(op, lazyOpsCallback); 2559 } 2560 2561 LogicalResult 2562 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { 2563 return impl->finalize(shouldMaterialize); 2564 } 2565 2566 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { 2567 return buffer.getBuffer().starts_with("ML\xefR"); 2568 } 2569 2570 /// Read the bytecode from the provided memory buffer reference. 2571 /// `bufferOwnerRef` if provided is the owning source manager for the buffer, 2572 /// and may be used to extend the lifetime of the buffer. 2573 static LogicalResult 2574 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, 2575 const ParserConfig &config, 2576 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2577 Location sourceFileLoc = 2578 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2579 /*line=*/0, /*column=*/0); 2580 if (!isBytecode(buffer)) { 2581 return emitError(sourceFileLoc, 2582 "input buffer is not an MLIR bytecode file"); 2583 } 2584 2585 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, 2586 buffer, bufferOwnerRef); 2587 return reader.read(block, /*lazyOpsCallback=*/nullptr); 2588 } 2589 2590 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, 2591 const ParserConfig &config) { 2592 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); 2593 } 2594 LogicalResult 2595 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 2596 Block *block, const ParserConfig &config) { 2597 return readBytecodeFileImpl( 2598 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, 2599 sourceMgr); 2600 } 2601