1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Bytecode/BytecodeReader.h" 10 #include "mlir/AsmParser/AsmParser.h" 11 #include "mlir/Bytecode/BytecodeImplementation.h" 12 #include "mlir/Bytecode/BytecodeOpInterface.h" 13 #include "mlir/Bytecode/Encoding.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/Verifier.h" 18 #include "mlir/IR/Visitors.h" 19 #include "mlir/Support/LLVM.h" 20 #include "mlir/Support/LogicalResult.h" 21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/ScopeExit.h" 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/Support/Endian.h" 26 #include "llvm/Support/MemoryBufferRef.h" 27 #include "llvm/Support/SourceMgr.h" 28 29 #include <cstddef> 30 #include <list> 31 #include <memory> 32 #include <numeric> 33 #include <optional> 34 35 #define DEBUG_TYPE "mlir-bytecode-reader" 36 37 using namespace mlir; 38 39 /// Stringify the given section ID. 
40 static std::string toString(bytecode::Section::ID sectionID) { 41 switch (sectionID) { 42 case bytecode::Section::kString: 43 return "String (0)"; 44 case bytecode::Section::kDialect: 45 return "Dialect (1)"; 46 case bytecode::Section::kAttrType: 47 return "AttrType (2)"; 48 case bytecode::Section::kAttrTypeOffset: 49 return "AttrTypeOffset (3)"; 50 case bytecode::Section::kIR: 51 return "IR (4)"; 52 case bytecode::Section::kResource: 53 return "Resource (5)"; 54 case bytecode::Section::kResourceOffset: 55 return "ResourceOffset (6)"; 56 case bytecode::Section::kDialectVersions: 57 return "DialectVersions (7)"; 58 case bytecode::Section::kProperties: 59 return "Properties (8)"; 60 default: 61 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); 62 } 63 } 64 65 /// Returns true if the given top-level section ID is optional. 66 static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { 67 switch (sectionID) { 68 case bytecode::Section::kString: 69 case bytecode::Section::kDialect: 70 case bytecode::Section::kAttrType: 71 case bytecode::Section::kAttrTypeOffset: 72 case bytecode::Section::kIR: 73 return false; 74 case bytecode::Section::kResource: 75 case bytecode::Section::kResourceOffset: 76 case bytecode::Section::kDialectVersions: 77 return true; 78 case bytecode::Section::kProperties: 79 return version < bytecode::kNativePropertiesEncoding; 80 default: 81 llvm_unreachable("unknown section ID"); 82 } 83 } 84 85 //===----------------------------------------------------------------------===// 86 // EncodingReader 87 //===----------------------------------------------------------------------===// 88 89 namespace { 90 class EncodingReader { 91 public: 92 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) 93 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {} 94 explicit EncodingReader(StringRef contents, Location fileLoc) 95 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), 
96 contents.size()}, 97 fileLoc) {} 98 99 /// Returns true if the entire section has been read. 100 bool empty() const { return dataIt == buffer.end(); } 101 102 /// Returns the remaining size of the bytecode. 103 size_t size() const { return buffer.end() - dataIt; } 104 105 /// Align the current reader position to the specified alignment. 106 LogicalResult alignTo(unsigned alignment) { 107 if (!llvm::isPowerOf2_32(alignment)) 108 return emitError("expected alignment to be a power-of-two"); 109 110 auto isUnaligned = [&](const uint8_t *ptr) { 111 return ((uintptr_t)ptr & (alignment - 1)) != 0; 112 }; 113 114 // Ensure the data buffer was sufficiently aligned in the first place. 115 if (LLVM_UNLIKELY(isUnaligned(buffer.begin()))) { 116 return emitError("expected bytecode buffer to be aligned to ", alignment, 117 ", but got pointer: '0x" + 118 llvm::utohexstr((uintptr_t)buffer.begin()) + "'"); 119 } 120 121 // Shift the reader position to the next alignment boundary. 122 while (isUnaligned(dataIt)) { 123 uint8_t padding; 124 if (failed(parseByte(padding))) 125 return failure(); 126 if (padding != bytecode::kAlignmentByte) { 127 return emitError("expected alignment byte (0xCB), but got: '0x" + 128 llvm::utohexstr(padding) + "'"); 129 } 130 } 131 132 // Ensure the data iterator is now aligned. This case is unlikely because we 133 // *just* went through the effort to align the data iterator. 134 if (LLVM_UNLIKELY(isUnaligned(dataIt))) { 135 return emitError("expected data iterator aligned to ", alignment, 136 ", but got pointer: '0x" + 137 llvm::utohexstr((uintptr_t)dataIt) + "'"); 138 } 139 140 return success(); 141 } 142 143 /// Emit an error using the given arguments. 144 template <typename... Args> 145 InFlightDiagnostic emitError(Args &&...args) const { 146 return ::emitError(fileLoc).append(std::forward<Args>(args)...); 147 } 148 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } 149 150 /// Parse a single byte from the stream. 
151 template <typename T> 152 LogicalResult parseByte(T &value) { 153 if (empty()) 154 return emitError("attempting to parse a byte at the end of the bytecode"); 155 value = static_cast<T>(*dataIt++); 156 return success(); 157 } 158 /// Parse a range of bytes of 'length' into the given result. 159 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { 160 if (length > size()) { 161 return emitError("attempting to parse ", length, " bytes when only ", 162 size(), " remain"); 163 } 164 result = {dataIt, length}; 165 dataIt += length; 166 return success(); 167 } 168 /// Parse a range of bytes of 'length' into the given result, which can be 169 /// assumed to be large enough to hold `length`. 170 LogicalResult parseBytes(size_t length, uint8_t *result) { 171 if (length > size()) { 172 return emitError("attempting to parse ", length, " bytes when only ", 173 size(), " remain"); 174 } 175 memcpy(result, dataIt, length); 176 dataIt += length; 177 return success(); 178 } 179 180 /// Parse an aligned blob of data, where the alignment was encoded alongside 181 /// the data. 182 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, 183 uint64_t &alignment) { 184 uint64_t dataSize; 185 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || 186 failed(alignTo(alignment))) 187 return failure(); 188 return parseBytes(dataSize, data); 189 } 190 191 /// Parse a variable length encoded integer from the byte stream. The first 192 /// encoded byte contains a prefix in the low bits indicating the encoded 193 /// length of the value. This length prefix is a bit sequence of '0's followed 194 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes 195 /// (not including the prefix byte). All remaining bits in the first byte, 196 /// along with all of the bits in additional bytes, provide the value of the 197 /// integer encoded in little-endian order. 
198 LogicalResult parseVarInt(uint64_t &result) { 199 // Parse the first byte of the encoding, which contains the length prefix. 200 if (failed(parseByte(result))) 201 return failure(); 202 203 // Handle the overwhelmingly common case where the value is stored in a 204 // single byte. In this case, the first bit is the `1` marker bit. 205 if (LLVM_LIKELY(result & 1)) { 206 result >>= 1; 207 return success(); 208 } 209 210 // Handle the overwhelming uncommon case where the value required all 8 211 // bytes (i.e. a really really big number). In this case, the marker byte is 212 // all zeros: `00000000`. 213 if (LLVM_UNLIKELY(result == 0)) { 214 llvm::support::ulittle64_t resultLE; 215 if (failed(parseBytes(sizeof(resultLE), 216 reinterpret_cast<uint8_t *>(&resultLE)))) 217 return failure(); 218 result = resultLE; 219 return success(); 220 } 221 return parseMultiByteVarInt(result); 222 } 223 224 /// Parse a signed variable length encoded integer from the byte stream. A 225 /// signed varint is encoded as a normal varint with zigzag encoding applied, 226 /// i.e. the low bit of the value is used to indicate the sign. 227 LogicalResult parseSignedVarInt(uint64_t &result) { 228 if (failed(parseVarInt(result))) 229 return failure(); 230 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) 231 result = (result >> 1) ^ (~(result & 1) + 1); 232 return success(); 233 } 234 235 /// Parse a variable length encoded integer whose low bit is used to encode an 236 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. 237 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { 238 if (failed(parseVarInt(result))) 239 return failure(); 240 flag = result & 1; 241 result >>= 1; 242 return success(); 243 } 244 245 /// Skip the first `length` bytes within the reader. 
246 LogicalResult skipBytes(size_t length) { 247 if (length > size()) { 248 return emitError("attempting to skip ", length, " bytes when only ", 249 size(), " remain"); 250 } 251 dataIt += length; 252 return success(); 253 } 254 255 /// Parse a null-terminated string into `result` (without including the NUL 256 /// terminator). 257 LogicalResult parseNullTerminatedString(StringRef &result) { 258 const char *startIt = (const char *)dataIt; 259 const char *nulIt = (const char *)memchr(startIt, 0, size()); 260 if (!nulIt) 261 return emitError( 262 "malformed null-terminated string, no null character found"); 263 264 result = StringRef(startIt, nulIt - startIt); 265 dataIt = (const uint8_t *)nulIt + 1; 266 return success(); 267 } 268 269 /// Parse a section header, placing the kind of section in `sectionID` and the 270 /// contents of the section in `sectionData`. 271 LogicalResult parseSection(bytecode::Section::ID §ionID, 272 ArrayRef<uint8_t> §ionData) { 273 uint8_t sectionIDAndHasAlignment; 274 uint64_t length; 275 if (failed(parseByte(sectionIDAndHasAlignment)) || 276 failed(parseVarInt(length))) 277 return failure(); 278 279 // Extract the section ID and whether the section is aligned. The high bit 280 // of the ID is the alignment flag. 281 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & 282 0b01111111); 283 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; 284 285 // Check that the section is actually valid before trying to process its 286 // data. 287 if (sectionID >= bytecode::Section::kNumSections) 288 return emitError("invalid section ID: ", unsigned(sectionID)); 289 290 // Process the section alignment if present. 291 if (hasAlignment) { 292 uint64_t alignment; 293 if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) 294 return failure(); 295 } 296 297 // Parse the actual section data. 
298 return parseBytes(static_cast<size_t>(length), sectionData); 299 } 300 301 Location getLoc() const { return fileLoc; } 302 303 private: 304 /// Parse a variable length encoded integer from the byte stream. This method 305 /// is a fallback when the number of bytes used to encode the value is greater 306 /// than 1, but less than the max (9). The provided `result` value can be 307 /// assumed to already contain the first byte of the value. 308 /// NOTE: This method is marked noinline to avoid pessimizing the common case 309 /// of single byte encoding. 310 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) { 311 // Count the number of trailing zeros in the marker byte, this indicates the 312 // number of trailing bytes that are part of the value. We use `uint32_t` 313 // here because we only care about the first byte, and so that be actually 314 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop 315 // implementation). 316 uint32_t numBytes = llvm::countr_zero<uint32_t>(result); 317 assert(numBytes > 0 && numBytes <= 7 && 318 "unexpected number of trailing zeros in varint encoding"); 319 320 // Parse in the remaining bytes of the value. 321 llvm::support::ulittle64_t resultLE(result); 322 if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1))) 323 return failure(); 324 325 // Shift out the low-order bits that were used to mark how the value was 326 // encoded. 327 result = resultLE >> (numBytes + 1); 328 return success(); 329 } 330 331 /// The bytecode buffer. 332 ArrayRef<uint8_t> buffer; 333 334 /// The current iterator within the 'buffer'. 335 const uint8_t *dataIt; 336 337 /// A location for the bytecode used to report errors. 338 Location fileLoc; 339 }; 340 } // namespace 341 342 /// Resolve an index into the given entry list. 
`entry` may either be a 343 /// reference, in which case it is assigned to the corresponding value in 344 /// `entries`, or a pointer, in which case it is assigned to the address of the 345 /// element in `entries`. 346 template <typename RangeT, typename T> 347 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, 348 uint64_t index, T &entry, 349 StringRef entryStr) { 350 if (index >= entries.size()) 351 return reader.emitError("invalid ", entryStr, " index: ", index); 352 353 // If the provided entry is a pointer, resolve to the address of the entry. 354 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) 355 entry = entries[index]; 356 else 357 entry = &entries[index]; 358 return success(); 359 } 360 361 /// Parse and resolve an index into the given entry list. 362 template <typename RangeT, typename T> 363 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, 364 T &entry, StringRef entryStr) { 365 uint64_t entryIdx; 366 if (failed(reader.parseVarInt(entryIdx))) 367 return failure(); 368 return resolveEntry(reader, entries, entryIdx, entry, entryStr); 369 } 370 371 //===----------------------------------------------------------------------===// 372 // StringSectionReader 373 //===----------------------------------------------------------------------===// 374 375 namespace { 376 /// This class is used to read references to the string section from the 377 /// bytecode. 378 class StringSectionReader { 379 public: 380 /// Initialize the string section reader with the given section data. 381 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); 382 383 /// Parse a shared string from the string section. The shared string is 384 /// encoded using an index to a corresponding string in the string section. 
385 LogicalResult parseString(EncodingReader &reader, StringRef &result) { 386 return parseEntry(reader, strings, result, "string"); 387 } 388 389 /// Parse a shared string from the string section. The shared string is 390 /// encoded using an index to a corresponding string in the string section. 391 /// This variant parses a flag compressed with the index. 392 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, 393 bool &flag) { 394 uint64_t entryIdx; 395 if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) 396 return failure(); 397 return parseStringAtIndex(reader, entryIdx, result); 398 } 399 400 /// Parse a shared string from the string section. The shared string is 401 /// encoded using an index to a corresponding string in the string section. 402 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, 403 StringRef &result) { 404 return resolveEntry(reader, strings, index, result, "string"); 405 } 406 407 private: 408 /// The table of strings referenced within the bytecode file. 409 SmallVector<StringRef> strings; 410 }; 411 } // namespace 412 413 LogicalResult StringSectionReader::initialize(Location fileLoc, 414 ArrayRef<uint8_t> sectionData) { 415 EncodingReader stringReader(sectionData, fileLoc); 416 417 // Parse the number of strings in the section. 418 uint64_t numStrings; 419 if (failed(stringReader.parseVarInt(numStrings))) 420 return failure(); 421 strings.resize(numStrings); 422 423 // Parse each of the strings. The sizes of the strings are encoded in reverse 424 // order, so that's the order we populate the table. 
425 size_t stringDataEndOffset = sectionData.size(); 426 for (StringRef &string : llvm::reverse(strings)) { 427 uint64_t stringSize; 428 if (failed(stringReader.parseVarInt(stringSize))) 429 return failure(); 430 if (stringDataEndOffset < stringSize) { 431 return stringReader.emitError( 432 "string size exceeds the available data size"); 433 } 434 435 // Extract the string from the data, dropping the null character. 436 size_t stringOffset = stringDataEndOffset - stringSize; 437 string = StringRef( 438 reinterpret_cast<const char *>(sectionData.data() + stringOffset), 439 stringSize - 1); 440 stringDataEndOffset = stringOffset; 441 } 442 443 // Check that the only remaining data was for the strings, i.e. the reader 444 // should be at the same offset as the first string. 445 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { 446 return stringReader.emitError("unexpected trailing data between the " 447 "offsets for strings and their data"); 448 } 449 return success(); 450 } 451 452 //===----------------------------------------------------------------------===// 453 // BytecodeDialect 454 //===----------------------------------------------------------------------===// 455 456 namespace { 457 class DialectReader; 458 459 /// This struct represents a dialect entry within the bytecode. 460 struct BytecodeDialect { 461 /// Load the dialect into the provided context if it hasn't been loaded yet. 462 /// Returns failure if the dialect couldn't be loaded *and* the provided 463 /// context does not allow unregistered dialects. The provided reader is used 464 /// for error emission if necessary. 465 LogicalResult load(const DialectReader &reader, MLIRContext *ctx); 466 467 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can 468 /// only be called after `load`. 
469 Dialect *getLoadedDialect() const { 470 assert(dialect && 471 "expected `load` to be invoked before `getLoadedDialect`"); 472 return *dialect; 473 } 474 475 /// The loaded dialect entry. This field is std::nullopt if we haven't 476 /// attempted to load, nullptr if we failed to load, otherwise the loaded 477 /// dialect. 478 std::optional<Dialect *> dialect; 479 480 /// The bytecode interface of the dialect, or nullptr if the dialect does not 481 /// implement the bytecode interface. This field should only be checked if the 482 /// `dialect` field is not std::nullopt. 483 const BytecodeDialectInterface *interface = nullptr; 484 485 /// The name of the dialect. 486 StringRef name; 487 488 /// A buffer containing the encoding of the dialect version parsed. 489 ArrayRef<uint8_t> versionBuffer; 490 491 /// Lazy loaded dialect version from the handle above. 492 std::unique_ptr<DialectVersion> loadedVersion; 493 }; 494 495 /// This struct represents an operation name entry within the bytecode. 496 struct BytecodeOperationName { 497 BytecodeOperationName(BytecodeDialect *dialect, StringRef name, 498 std::optional<bool> wasRegistered) 499 : dialect(dialect), name(name), wasRegistered(wasRegistered) {} 500 501 /// The loaded operation name, or std::nullopt if it hasn't been processed 502 /// yet. 503 std::optional<OperationName> opName; 504 505 /// The dialect that owns this operation name. 506 BytecodeDialect *dialect; 507 508 /// The name of the operation, without the dialect prefix. 509 StringRef name; 510 511 /// Whether this operation was registered when the bytecode was produced. 512 /// This flag is populated when bytecode version >=kNativePropertiesEncoding. 513 std::optional<bool> wasRegistered; 514 }; 515 } // namespace 516 517 /// Parse a single dialect group encoded in the byte stream. 
518 static LogicalResult parseDialectGrouping( 519 EncodingReader &reader, 520 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 521 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { 522 // Parse the dialect and the number of entries in the group. 523 std::unique_ptr<BytecodeDialect> *dialect; 524 if (failed(parseEntry(reader, dialects, dialect, "dialect"))) 525 return failure(); 526 uint64_t numEntries; 527 if (failed(reader.parseVarInt(numEntries))) 528 return failure(); 529 530 for (uint64_t i = 0; i < numEntries; ++i) 531 if (failed(entryCallback(dialect->get()))) 532 return failure(); 533 return success(); 534 } 535 536 //===----------------------------------------------------------------------===// 537 // ResourceSectionReader 538 //===----------------------------------------------------------------------===// 539 540 namespace { 541 /// This class is used to read the resource section from the bytecode. 542 class ResourceSectionReader { 543 public: 544 /// Initialize the resource section reader with the given section data. 545 LogicalResult 546 initialize(Location fileLoc, const ParserConfig &config, 547 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 548 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 549 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 550 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); 551 552 /// Parse a dialect resource handle from the resource section. 553 LogicalResult parseResourceHandle(EncodingReader &reader, 554 AsmDialectResourceHandle &result) { 555 return parseEntry(reader, dialectResources, result, "resource handle"); 556 } 557 558 private: 559 /// The table of dialect resources within the bytecode file. 
560 SmallVector<AsmDialectResourceHandle> dialectResources; 561 llvm::StringMap<std::string> dialectResourceHandleRenamingMap; 562 }; 563 564 class ParsedResourceEntry : public AsmParsedResourceEntry { 565 public: 566 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, 567 EncodingReader &reader, StringSectionReader &stringReader, 568 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 569 : key(key), kind(kind), reader(reader), stringReader(stringReader), 570 bufferOwnerRef(bufferOwnerRef) {} 571 ~ParsedResourceEntry() override = default; 572 573 StringRef getKey() const final { return key; } 574 575 InFlightDiagnostic emitError() const final { return reader.emitError(); } 576 577 AsmResourceEntryKind getKind() const final { return kind; } 578 579 FailureOr<bool> parseAsBool() const final { 580 if (kind != AsmResourceEntryKind::Bool) 581 return emitError() << "expected a bool resource entry, but found a " 582 << toString(kind) << " entry instead"; 583 584 bool value; 585 if (failed(reader.parseByte(value))) 586 return failure(); 587 return value; 588 } 589 FailureOr<std::string> parseAsString() const final { 590 if (kind != AsmResourceEntryKind::String) 591 return emitError() << "expected a string resource entry, but found a " 592 << toString(kind) << " entry instead"; 593 594 StringRef string; 595 if (failed(stringReader.parseString(reader, string))) 596 return failure(); 597 return string.str(); 598 } 599 600 FailureOr<AsmResourceBlob> 601 parseAsBlob(BlobAllocatorFn allocator) const final { 602 if (kind != AsmResourceEntryKind::Blob) 603 return emitError() << "expected a blob resource entry, but found a " 604 << toString(kind) << " entry instead"; 605 606 ArrayRef<uint8_t> data; 607 uint64_t alignment; 608 if (failed(reader.parseBlobAndAlignment(data, alignment))) 609 return failure(); 610 611 // If we have an extendable reference to the buffer owner, we don't need to 612 // allocate a new buffer for the data, and can use the data directly. 
613 if (bufferOwnerRef) { 614 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), 615 data.size()); 616 617 // Allocate an unmanager buffer which captures a reference to the owner. 618 // For now we just mark this as immutable, but in the future we should 619 // explore marking this as mutable when desired. 620 return UnmanagedAsmResourceBlob::allocateWithAlign( 621 charData, alignment, 622 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); 623 } 624 625 // Allocate memory for the blob using the provided allocator and copy the 626 // data into it. 627 AsmResourceBlob blob = allocator(data.size(), alignment); 628 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && 629 blob.isMutable() && 630 "blob allocator did not return a properly aligned address"); 631 memcpy(blob.getMutableData().data(), data.data(), data.size()); 632 return blob; 633 } 634 635 private: 636 StringRef key; 637 AsmResourceEntryKind kind; 638 EncodingReader &reader; 639 StringSectionReader &stringReader; 640 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; 641 }; 642 } // namespace 643 644 template <typename T> 645 static LogicalResult 646 parseResourceGroup(Location fileLoc, bool allowEmpty, 647 EncodingReader &offsetReader, EncodingReader &resourceReader, 648 StringSectionReader &stringReader, T *handler, 649 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, 650 function_ref<StringRef(StringRef)> remapKey = {}, 651 function_ref<LogicalResult(StringRef)> processKeyFn = {}) { 652 uint64_t numResources; 653 if (failed(offsetReader.parseVarInt(numResources))) 654 return failure(); 655 656 for (uint64_t i = 0; i < numResources; ++i) { 657 StringRef key; 658 AsmResourceEntryKind kind; 659 uint64_t resourceOffset; 660 ArrayRef<uint8_t> data; 661 if (failed(stringReader.parseString(offsetReader, key)) || 662 failed(offsetReader.parseVarInt(resourceOffset)) || 663 failed(offsetReader.parseByte(kind)) || 664 
failed(resourceReader.parseBytes(resourceOffset, data))) 665 return failure(); 666 667 // Process the resource key. 668 if ((processKeyFn && failed(processKeyFn(key)))) 669 return failure(); 670 671 // If the resource data is empty and we allow it, don't error out when 672 // parsing below, just skip it. 673 if (allowEmpty && data.empty()) 674 continue; 675 676 // Ignore the entry if we don't have a valid handler. 677 if (!handler) 678 continue; 679 680 // Otherwise, parse the resource value. 681 EncodingReader entryReader(data, fileLoc); 682 key = remapKey(key); 683 ParsedResourceEntry entry(key, kind, entryReader, stringReader, 684 bufferOwnerRef); 685 if (failed(handler->parseResource(entry))) 686 return failure(); 687 if (!entryReader.empty()) { 688 return entryReader.emitError( 689 "unexpected trailing bytes in resource entry '", key, "'"); 690 } 691 } 692 return success(); 693 } 694 695 LogicalResult ResourceSectionReader::initialize( 696 Location fileLoc, const ParserConfig &config, 697 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 698 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, 699 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, 700 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 701 EncodingReader resourceReader(sectionData, fileLoc); 702 EncodingReader offsetReader(offsetSectionData, fileLoc); 703 704 // Read the number of external resource providers. 705 uint64_t numExternalResourceGroups; 706 if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) 707 return failure(); 708 709 // Utility functor that dispatches to `parseResourceGroup`, but implicitly 710 // provides most of the arguments. 
711 auto parseGroup = [&](auto *handler, bool allowEmpty = false, 712 function_ref<LogicalResult(StringRef)> keyFn = {}) { 713 auto resolveKey = [&](StringRef key) -> StringRef { 714 auto it = dialectResourceHandleRenamingMap.find(key); 715 if (it == dialectResourceHandleRenamingMap.end()) 716 return ""; 717 return it->second; 718 }; 719 720 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, 721 stringReader, handler, bufferOwnerRef, resolveKey, 722 keyFn); 723 }; 724 725 // Read the external resources from the bytecode. 726 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { 727 StringRef key; 728 if (failed(stringReader.parseString(offsetReader, key))) 729 return failure(); 730 731 // Get the handler for these resources. 732 // TODO: Should we require handling external resources in some scenarios? 733 AsmResourceParser *handler = config.getResourceParser(key); 734 if (!handler) { 735 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key 736 << "'"; 737 } 738 739 if (failed(parseGroup(handler))) 740 return failure(); 741 } 742 743 // Read the dialect resources from the bytecode. 744 MLIRContext *ctx = fileLoc->getContext(); 745 while (!offsetReader.empty()) { 746 std::unique_ptr<BytecodeDialect> *dialect; 747 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || 748 failed((*dialect)->load(dialectReader, ctx))) 749 return failure(); 750 Dialect *loadedDialect = (*dialect)->getLoadedDialect(); 751 if (!loadedDialect) { 752 return resourceReader.emitError() 753 << "dialect '" << (*dialect)->name << "' is unknown"; 754 } 755 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); 756 if (!handler) { 757 return resourceReader.emitError() 758 << "unexpected resources for dialect '" << (*dialect)->name << "'"; 759 } 760 761 // Ensure that each resource is declared before being processed. 
auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
      // Declare the key with the dialect's resource handler; a key the handler
      // does not accept is reported as an error.
      FailureOr<AsmDialectResourceHandle> handle =
          handler->declareResource(key);
      if (failed(handle)) {
        return resourceReader.emitError()
               << "unknown 'resource' key '" << key << "' for dialect '"
               << (*dialect)->name << "'";
      }
      // Map the key as spelled in the bytecode to the key the handler reports
      // for the declared handle, and remember the handle itself.
      dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
      dialectResources.push_back(*handle);
      return success();
    };

    // Parse the resources for this dialect. We allow empty resources because we
    // just treat these as declarations.
    if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
      return failure();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Attribute/Type Reader
//===----------------------------------------------------------------------===//

namespace {
/// This class provides support for reading attribute and type entries from the
/// bytecode. Attribute and Type entries are read lazily on demand, so we use
/// this reader to manage when to actually parse them from the bytecode.
class AttrTypeReader {
  /// This class represents a single attribute or type entry.
  template <typename T>
  struct Entry {
    /// The entry, or null if it hasn't been resolved yet.
    T entry = {};
    /// The parent dialect of this entry.
    BytecodeDialect *dialect = nullptr;
    /// A flag indicating if the entry was encoded using a custom encoding,
    /// instead of using the textual assembly format.
    bool hasCustomEncoding = false;
    /// The raw data of this entry in the bytecode (a non-owning view).
    ArrayRef<uint8_t> data;
  };
  using AttrEntry = Entry<Attribute>;
  using TypeEntry = Entry<Type>;

public:
  /// Construct a reader. All arguments except `fileLoc` are stored by
  /// reference, so the referenced objects must outlive this reader.
  AttrTypeReader(StringSectionReader &stringReader,
                 ResourceSectionReader &resourceReader,
                 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
                 uint64_t &bytecodeVersion, Location fileLoc,
                 const ParserConfig &config)
      : stringReader(stringReader), resourceReader(resourceReader),
        dialectsMap(dialectsMap), fileLoc(fileLoc),
        bytecodeVersion(bytecodeVersion), parserConfig(config) {}

  /// Initialize the attribute and type information within the reader.
  LogicalResult
  initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
             ArrayRef<uint8_t> sectionData,
             ArrayRef<uint8_t> offsetSectionData);

  /// Resolve the attribute or type at the given index. Returns nullptr on
  /// failure.
  Attribute resolveAttribute(size_t index) {
    return resolveEntry(attributes, index, "Attribute");
  }
  Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }

  /// Parse a reference to an attribute or type using the given reader.
  /// References are encoded as a varint index into the corresponding table;
  /// failure is returned if the index cannot be read or resolved.
  LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
    uint64_t attrIdx;
    if (failed(reader.parseVarInt(attrIdx)))
      return failure();
    result = resolveAttribute(attrIdx);
    return success(!!result);
  }
  /// Parse an optional attribute reference. The index is encoded together
  /// with a presence flag; when the flag is clear, `result` is left untouched
  /// and success is returned.
  LogicalResult parseOptionalAttribute(EncodingReader &reader,
                                       Attribute &result) {
    uint64_t attrIdx;
    bool flag;
    if (failed(reader.parseVarIntWithFlag(attrIdx, flag)))
      return failure();
    if (!flag)
      return success();
    result = resolveAttribute(attrIdx);
    return success(!!result);
  }

  LogicalResult parseType(EncodingReader &reader, Type &result) {
    uint64_t typeIdx;
    if (failed(reader.parseVarInt(typeIdx)))
      return failure();
    result = resolveType(typeIdx);
    return success(!!result);
  }

  /// Parse an attribute reference and require it to be of kind `T`, emitting
  /// an error if the resolved attribute has a different kind.
  template <typename T>
  LogicalResult parseAttribute(EncodingReader &reader, T &result) {
    Attribute baseResult;
    if (failed(parseAttribute(reader, baseResult)))
      return failure();
    if ((result = dyn_cast<T>(baseResult)))
      return success();
    return reader.emitError("expected attribute of type: ",
                            llvm::getTypeName<T>(), ", but got: ", baseResult);
  }

private:
  /// Resolve the given entry at `index`.
  template <typename T>
  T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
                 StringRef entryType);

  /// Parse an entry using the given reader that was encoded using the textual
  /// assembly format.
  template <typename T>
  LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
                              StringRef entryType);

  /// Parse an entry using the given reader that was encoded using a custom
  /// bytecode format.
  template <typename T>
  LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
                                 StringRef entryType);

  /// The string section reader used to resolve string references when parsing
  /// custom encoded attribute/type entries.
  StringSectionReader &stringReader;

  /// The resource section reader used to resolve resource references when
  /// parsing custom encoded attribute/type entries.
  ResourceSectionReader &resourceReader;

  /// The map of the loaded dialects used to retrieve dialect information, such
  /// as the dialect version.
  const llvm::StringMap<BytecodeDialect *> &dialectsMap;

  /// The set of attribute and type entries.
  SmallVector<AttrEntry> attributes;
  SmallVector<TypeEntry> types;

  /// A location used for error emission.
  Location fileLoc;

  /// Current bytecode version being used.
  uint64_t &bytecodeVersion;

  /// Reference to the parser configuration.
  const ParserConfig &parserConfig;
};

/// Implementation of the DialectBytecodeReader interface on top of an
/// EncodingReader. Attribute and type references are delegated to the
/// AttrTypeReader and resource handles to the ResourceSectionReader.
class DialectReader : public DialectBytecodeReader {
public:
  DialectReader(AttrTypeReader &attrTypeReader,
                StringSectionReader &stringReader,
                ResourceSectionReader &resourceReader,
                const llvm::StringMap<BytecodeDialect *> &dialectsMap,
                EncodingReader &reader, uint64_t &bytecodeVersion)
      : attrTypeReader(attrTypeReader), stringReader(stringReader),
        resourceReader(resourceReader), dialectsMap(dialectsMap),
        reader(reader), bytecodeVersion(bytecodeVersion) {}

  /// Errors are reported through the underlying encoding reader.
  InFlightDiagnostic emitError(const Twine &msg) const override {
    return reader.emitError(msg);
  }

  FailureOr<const DialectVersion *>
  getDialectVersion(StringRef dialectName) const override {
    // First check if the dialect is available in the map.
    auto dialectEntry = dialectsMap.find(dialectName);
    if (dialectEntry == dialectsMap.end())
      return failure();
    // If the dialect was found, try to load it. This will trigger reading the
    // bytecode version from the version buffer if it wasn't already processed.
    // Return failure if either of those two actions could not be completed.
    if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
        dialectEntry->getValue()->loadedVersion == nullptr)
      return failure();
    return dialectEntry->getValue()->loadedVersion.get();
  }

  /// The context is taken from the location of the underlying encoding reader.
  MLIRContext *getContext() const override { return getLoc().getContext(); }

  /// The version is held by reference, so this reflects the version of the
  /// bytecode currently being read.
  uint64_t getBytecodeVersion() const override { return bytecodeVersion; }

  /// Return a reader that shares all of this reader's state but consumes data
  /// from `encReader` instead; used to read from sub-buffers.
  DialectReader withEncodingReader(EncodingReader &encReader) const {
    return DialectReader(attrTypeReader, stringReader, resourceReader,
                         dialectsMap, encReader, bytecodeVersion);
  }

  /// Location used for diagnostics, taken from the underlying encoding reader.
  Location getLoc() const { return reader.getLoc(); }

  //===--------------------------------------------------------------------===//
  // IR
  //===--------------------------------------------------------------------===//

  LogicalResult readAttribute(Attribute &result) override {
    return attrTypeReader.parseAttribute(reader, result);
  }
  LogicalResult readOptionalAttribute(Attribute &result) override {
    return attrTypeReader.parseOptionalAttribute(reader, result);
  }
  LogicalResult readType(Type &result) override {
    return attrTypeReader.parseType(reader, result);
  }

  FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
    AsmDialectResourceHandle handle;
    if (failed(resourceReader.parseResourceHandle(reader, handle)))
      return failure();
    return handle;
  }

  //===--------------------------------------------------------------------===//
  // Primitives
  //===--------------------------------------------------------------------===//

  LogicalResult readVarInt(uint64_t &result) override {
    return reader.parseVarInt(result);
  }

  /// The encoding reader decodes signed varints into a uint64_t; reinterpret
  /// those bits as a signed value.
  LogicalResult readSignedVarInt(int64_t &result) override {
    uint64_t unsignedResult;
    if (failed(reader.parseSignedVarInt(unsignedResult)))
      return failure();
    result = static_cast<int64_t>(unsignedResult);
    return success();
  }

FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { 994 // Small values are encoded using a single byte. 995 if (bitWidth <= 8) { 996 uint8_t value; 997 if (failed(reader.parseByte(value))) 998 return failure(); 999 return APInt(bitWidth, value); 1000 } 1001 1002 // Large values up to 64 bits are encoded using a single varint. 1003 if (bitWidth <= 64) { 1004 uint64_t value; 1005 if (failed(reader.parseSignedVarInt(value))) 1006 return failure(); 1007 return APInt(bitWidth, value); 1008 } 1009 1010 // Otherwise, for really big values we encode the array of active words in 1011 // the value. 1012 uint64_t numActiveWords; 1013 if (failed(reader.parseVarInt(numActiveWords))) 1014 return failure(); 1015 SmallVector<uint64_t, 4> words(numActiveWords); 1016 for (uint64_t i = 0; i < numActiveWords; ++i) 1017 if (failed(reader.parseSignedVarInt(words[i]))) 1018 return failure(); 1019 return APInt(bitWidth, words); 1020 } 1021 1022 FailureOr<APFloat> 1023 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { 1024 FailureOr<APInt> intVal = 1025 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); 1026 if (failed(intVal)) 1027 return failure(); 1028 return APFloat(semantics, *intVal); 1029 } 1030 1031 LogicalResult readString(StringRef &result) override { 1032 return stringReader.parseString(reader, result); 1033 } 1034 1035 LogicalResult readBlob(ArrayRef<char> &result) override { 1036 uint64_t dataSize; 1037 ArrayRef<uint8_t> data; 1038 if (failed(reader.parseVarInt(dataSize)) || 1039 failed(reader.parseBytes(dataSize, data))) 1040 return failure(); 1041 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), 1042 data.size()); 1043 return success(); 1044 } 1045 1046 LogicalResult readBool(bool &result) override { 1047 return reader.parseByte(result); 1048 } 1049 1050 private: 1051 AttrTypeReader &attrTypeReader; 1052 StringSectionReader &stringReader; 1053 ResourceSectionReader &resourceReader; 1054 const 
llvm::StringMap<BytecodeDialect *> &dialectsMap; 1055 EncodingReader &reader; 1056 uint64_t &bytecodeVersion; 1057 }; 1058 1059 /// Wraps the properties section and handles reading properties out of it. 1060 class PropertiesSectionReader { 1061 public: 1062 /// Initialize the properties section reader with the given section data. 1063 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { 1064 if (sectionData.empty()) 1065 return success(); 1066 EncodingReader propReader(sectionData, fileLoc); 1067 uint64_t count; 1068 if (failed(propReader.parseVarInt(count))) 1069 return failure(); 1070 // Parse the raw properties buffer. 1071 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) 1072 return failure(); 1073 1074 EncodingReader offsetsReader(propertiesBuffers, fileLoc); 1075 offsetTable.reserve(count); 1076 for (auto idx : llvm::seq<int64_t>(0, count)) { 1077 (void)idx; 1078 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); 1079 ArrayRef<uint8_t> rawProperties; 1080 uint64_t dataSize; 1081 if (failed(offsetsReader.parseVarInt(dataSize)) || 1082 failed(offsetsReader.parseBytes(dataSize, rawProperties))) 1083 return failure(); 1084 } 1085 if (!offsetsReader.empty()) 1086 return offsetsReader.emitError() 1087 << "Broken properties section: didn't exhaust the offsets table"; 1088 return success(); 1089 } 1090 1091 LogicalResult read(Location fileLoc, DialectReader &dialectReader, 1092 OperationName *opName, OperationState &opState) { 1093 uint64_t propertiesIdx; 1094 if (failed(dialectReader.readVarInt(propertiesIdx))) 1095 return failure(); 1096 if (propertiesIdx >= offsetTable.size()) 1097 return dialectReader.emitError("Properties idx out-of-bound for ") 1098 << opName->getStringRef(); 1099 size_t propertiesOffset = offsetTable[propertiesIdx]; 1100 if (propertiesIdx >= propertiesBuffers.size()) 1101 return dialectReader.emitError("Properties offset out-of-bound for ") 1102 << opName->getStringRef(); 1103 
1104 // Acquire the sub-buffer that represent the requested properties. 1105 ArrayRef<char> rawProperties; 1106 { 1107 // "Seek" to the requested offset by getting a new reader with the right 1108 // sub-buffer. 1109 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), 1110 fileLoc); 1111 // Properties are stored as a sequence of {size + raw_data}. 1112 if (failed( 1113 dialectReader.withEncodingReader(reader).readBlob(rawProperties))) 1114 return failure(); 1115 } 1116 // Setup a new reader to read from the `rawProperties` sub-buffer. 1117 EncodingReader reader( 1118 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); 1119 DialectReader propReader = dialectReader.withEncodingReader(reader); 1120 1121 auto *iface = opName->getInterface<BytecodeOpInterface>(); 1122 if (iface) 1123 return iface->readProperties(propReader, opState); 1124 if (opName->isRegistered()) 1125 return propReader.emitError( 1126 "has properties but missing BytecodeOpInterface for ") 1127 << opName->getStringRef(); 1128 // Unregistered op are storing properties as an attribute. 1129 return propReader.readAttribute(opState.propertiesAttr); 1130 } 1131 1132 private: 1133 /// The properties buffer referenced within the bytecode file. 1134 ArrayRef<uint8_t> propertiesBuffers; 1135 1136 /// Table of offset in the buffer above. 1137 SmallVector<int64_t> offsetTable; 1138 }; 1139 } // namespace 1140 1141 LogicalResult AttrTypeReader::initialize( 1142 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, 1143 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) { 1144 EncodingReader offsetReader(offsetSectionData, fileLoc); 1145 1146 // Parse the number of attribute and type entries. 
1147 uint64_t numAttributes, numTypes; 1148 if (failed(offsetReader.parseVarInt(numAttributes)) || 1149 failed(offsetReader.parseVarInt(numTypes))) 1150 return failure(); 1151 attributes.resize(numAttributes); 1152 types.resize(numTypes); 1153 1154 // A functor used to accumulate the offsets for the entries in the given 1155 // range. 1156 uint64_t currentOffset = 0; 1157 auto parseEntries = [&](auto &&range) { 1158 size_t currentIndex = 0, endIndex = range.size(); 1159 1160 // Parse an individual entry. 1161 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { 1162 auto &entry = range[currentIndex++]; 1163 1164 uint64_t entrySize; 1165 if (failed(offsetReader.parseVarIntWithFlag(entrySize, 1166 entry.hasCustomEncoding))) 1167 return failure(); 1168 1169 // Verify that the offset is actually valid. 1170 if (currentOffset + entrySize > sectionData.size()) { 1171 return offsetReader.emitError( 1172 "Attribute or Type entry offset points past the end of section"); 1173 } 1174 1175 entry.data = sectionData.slice(currentOffset, entrySize); 1176 entry.dialect = dialect; 1177 currentOffset += entrySize; 1178 return success(); 1179 }; 1180 while (currentIndex != endIndex) 1181 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) 1182 return failure(); 1183 return success(); 1184 }; 1185 1186 // Process each of the attributes, and then the types. 1187 if (failed(parseEntries(attributes)) || failed(parseEntries(types))) 1188 return failure(); 1189 1190 // Ensure that we read everything from the section. 
1191 if (!offsetReader.empty()) { 1192 return offsetReader.emitError( 1193 "unexpected trailing data in the Attribute/Type offset section"); 1194 } 1195 1196 return success(); 1197 } 1198 1199 template <typename T> 1200 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, 1201 StringRef entryType) { 1202 if (index >= entries.size()) { 1203 emitError(fileLoc) << "invalid " << entryType << " index: " << index; 1204 return {}; 1205 } 1206 1207 // If the entry has already been resolved, there is nothing left to do. 1208 Entry<T> &entry = entries[index]; 1209 if (entry.entry) 1210 return entry.entry; 1211 1212 // Parse the entry. 1213 EncodingReader reader(entry.data, fileLoc); 1214 1215 // Parse based on how the entry was encoded. 1216 if (entry.hasCustomEncoding) { 1217 if (failed(parseCustomEntry(entry, reader, entryType))) 1218 return T(); 1219 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { 1220 return T(); 1221 } 1222 1223 if (!reader.empty()) { 1224 reader.emitError("unexpected trailing bytes after " + entryType + " entry"); 1225 return T(); 1226 } 1227 return entry.entry; 1228 } 1229 1230 template <typename T> 1231 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, 1232 StringRef entryType) { 1233 StringRef asmStr; 1234 if (failed(reader.parseNullTerminatedString(asmStr))) 1235 return failure(); 1236 1237 // Invoke the MLIR assembly parser to parse the entry text. 1238 size_t numRead = 0; 1239 MLIRContext *context = fileLoc->getContext(); 1240 if constexpr (std::is_same_v<T, Type>) 1241 result = 1242 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); 1243 else 1244 result = ::parseAttribute(asmStr, context, Type(), &numRead, 1245 /*isKnownNullTerminated=*/true); 1246 if (!result) 1247 return failure(); 1248 1249 // Ensure there weren't dangling characters after the entry. 
1250 if (numRead != asmStr.size()) { 1251 return reader.emitError("trailing characters found after ", entryType, 1252 " assembly format: ", asmStr.drop_front(numRead)); 1253 } 1254 return success(); 1255 } 1256 1257 template <typename T> 1258 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, 1259 EncodingReader &reader, 1260 StringRef entryType) { 1261 DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, 1262 reader, bytecodeVersion); 1263 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) 1264 return failure(); 1265 1266 if constexpr (std::is_same_v<T, Type>) { 1267 // Try parsing with callbacks first if available. 1268 for (const auto &callback : 1269 parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) { 1270 if (failed( 1271 callback->read(dialectReader, entry.dialect->name, entry.entry))) 1272 return failure(); 1273 // Early return if parsing was successful. 1274 if (!!entry.entry) 1275 return success(); 1276 1277 // Reset the reader if we failed to parse, so we can fall through the 1278 // other parsing functions. 1279 reader = EncodingReader(entry.data, reader.getLoc()); 1280 } 1281 } else { 1282 // Try parsing with callbacks first if available. 1283 for (const auto &callback : 1284 parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) { 1285 if (failed( 1286 callback->read(dialectReader, entry.dialect->name, entry.entry))) 1287 return failure(); 1288 // Early return if parsing was successful. 1289 if (!!entry.entry) 1290 return success(); 1291 1292 // Reset the reader if we failed to parse, so we can fall through the 1293 // other parsing functions. 1294 reader = EncodingReader(entry.data, reader.getLoc()); 1295 } 1296 } 1297 1298 // Ensure that the dialect implements the bytecode interface. 
1299 if (!entry.dialect->interface) { 1300 return reader.emitError("dialect '", entry.dialect->name, 1301 "' does not implement the bytecode interface"); 1302 } 1303 1304 if constexpr (std::is_same_v<T, Type>) 1305 entry.entry = entry.dialect->interface->readType(dialectReader); 1306 else 1307 entry.entry = entry.dialect->interface->readAttribute(dialectReader); 1308 1309 return success(!!entry.entry); 1310 } 1311 1312 //===----------------------------------------------------------------------===// 1313 // Bytecode Reader 1314 //===----------------------------------------------------------------------===// 1315 1316 /// This class is used to read a bytecode buffer and translate it into MLIR. 1317 class mlir::BytecodeReader::Impl { 1318 struct RegionReadState; 1319 using LazyLoadableOpsInfo = 1320 std::list<std::pair<Operation *, RegionReadState>>; 1321 using LazyLoadableOpsMap = 1322 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; 1323 1324 public: 1325 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, 1326 llvm::MemoryBufferRef buffer, 1327 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) 1328 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), 1329 attrTypeReader(stringReader, resourceReader, dialectsMap, version, 1330 fileLoc, config), 1331 // Use the builtin unrealized conversion cast operation to represent 1332 // forward references to values that aren't yet defined. 1333 forwardRefOpState(UnknownLoc::get(config.getContext()), 1334 "builtin.unrealized_conversion_cast", ValueRange(), 1335 NoneType::get(config.getContext())), 1336 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} 1337 1338 /// Read the bytecode defined within `buffer` into the given block. 1339 LogicalResult read(Block *block, 1340 llvm::function_ref<bool(Operation *)> lazyOps); 1341 1342 /// Return the number of ops that haven't been materialized yet. 
1343 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } 1344 1345 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } 1346 1347 /// Materialize the provided operation, invoke the lazyOpsCallback on every 1348 /// newly found lazy operation. 1349 LogicalResult 1350 materialize(Operation *op, 1351 llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 1352 this->lazyOpsCallback = lazyOpsCallback; 1353 auto resetlazyOpsCallback = 1354 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); 1355 auto it = lazyLoadableOpsMap.find(op); 1356 assert(it != lazyLoadableOpsMap.end() && 1357 "materialize called on non-materializable op"); 1358 return materialize(it); 1359 } 1360 1361 /// Materialize all operations. 1362 LogicalResult materializeAll() { 1363 while (!lazyLoadableOpsMap.empty()) { 1364 if (failed(materialize(lazyLoadableOpsMap.begin()))) 1365 return failure(); 1366 } 1367 return success(); 1368 } 1369 1370 /// Finalize the lazy-loading by calling back with every op that hasn't been 1371 /// materialized to let the client decide if the op should be deleted or 1372 /// materialized. The op is materialized if the callback returns true, deleted 1373 /// otherwise. 
1374 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { 1375 while (!lazyLoadableOps.empty()) { 1376 Operation *op = lazyLoadableOps.begin()->first; 1377 if (shouldMaterialize(op)) { 1378 if (failed(materialize(lazyLoadableOpsMap.find(op)))) 1379 return failure(); 1380 continue; 1381 } 1382 op->dropAllReferences(); 1383 op->erase(); 1384 lazyLoadableOps.pop_front(); 1385 lazyLoadableOpsMap.erase(op); 1386 } 1387 return success(); 1388 } 1389 1390 private: 1391 LogicalResult materialize(LazyLoadableOpsMap::iterator it) { 1392 assert(it != lazyLoadableOpsMap.end() && 1393 "materialize called on non-materializable op"); 1394 valueScopes.emplace_back(); 1395 std::vector<RegionReadState> regionStack; 1396 regionStack.push_back(std::move(it->getSecond()->second)); 1397 lazyLoadableOps.erase(it->getSecond()); 1398 lazyLoadableOpsMap.erase(it); 1399 1400 while (!regionStack.empty()) 1401 if (failed(parseRegions(regionStack, regionStack.back()))) 1402 return failure(); 1403 return success(); 1404 } 1405 1406 /// Return the context for this config. 1407 MLIRContext *getContext() const { return config.getContext(); } 1408 1409 /// Parse the bytecode version. 1410 LogicalResult parseVersion(EncodingReader &reader); 1411 1412 //===--------------------------------------------------------------------===// 1413 // Dialect Section 1414 1415 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); 1416 1417 /// Parse an operation name reference using the given reader, and set the 1418 /// `wasRegistered` flag that indicates if the bytecode was produced by a 1419 /// context where opName was registered. 1420 FailureOr<OperationName> parseOpName(EncodingReader &reader, 1421 std::optional<bool> &wasRegistered); 1422 1423 //===--------------------------------------------------------------------===// 1424 // Attribute/Type Section 1425 1426 /// Parse an attribute or type using the given reader. 
1427 template <typename T> 1428 LogicalResult parseAttribute(EncodingReader &reader, T &result) { 1429 return attrTypeReader.parseAttribute(reader, result); 1430 } 1431 LogicalResult parseType(EncodingReader &reader, Type &result) { 1432 return attrTypeReader.parseType(reader, result); 1433 } 1434 1435 //===--------------------------------------------------------------------===// 1436 // Resource Section 1437 1438 LogicalResult 1439 parseResourceSection(EncodingReader &reader, 1440 std::optional<ArrayRef<uint8_t>> resourceData, 1441 std::optional<ArrayRef<uint8_t>> resourceOffsetData); 1442 1443 //===--------------------------------------------------------------------===// 1444 // IR Section 1445 1446 /// This struct represents the current read state of a range of regions. This 1447 /// struct is used to enable iterative parsing of regions. 1448 struct RegionReadState { 1449 RegionReadState(Operation *op, EncodingReader *reader, 1450 bool isIsolatedFromAbove) 1451 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} 1452 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, 1453 bool isIsolatedFromAbove) 1454 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), 1455 isIsolatedFromAbove(isIsolatedFromAbove) {} 1456 1457 /// The current regions being read. 1458 MutableArrayRef<Region>::iterator curRegion, endRegion; 1459 /// This is the reader to use for this region, this pointer is pointing to 1460 /// the parent region reader unless the current region is IsolatedFromAbove, 1461 /// in which case the pointer is pointing to the `owningReader` which is a 1462 /// section dedicated to the current region. 1463 EncodingReader *reader; 1464 std::unique_ptr<EncodingReader> owningReader; 1465 1466 /// The number of values defined immediately within this region. 1467 unsigned numValues = 0; 1468 1469 /// The current blocks of the region being read. 
1470 SmallVector<Block *> curBlocks; 1471 Region::iterator curBlock = {}; 1472 1473 /// The number of operations remaining to be read from the current block 1474 /// being read. 1475 uint64_t numOpsRemaining = 0; 1476 1477 /// A flag indicating if the regions being read are isolated from above. 1478 bool isIsolatedFromAbove = false; 1479 }; 1480 1481 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); 1482 LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, 1483 RegionReadState &readState); 1484 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, 1485 RegionReadState &readState, 1486 bool &isIsolatedFromAbove); 1487 1488 LogicalResult parseRegion(RegionReadState &readState); 1489 LogicalResult parseBlockHeader(EncodingReader &reader, 1490 RegionReadState &readState); 1491 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); 1492 1493 //===--------------------------------------------------------------------===// 1494 // Value Processing 1495 1496 /// Parse an operand reference using the given reader. Returns nullptr in the 1497 /// case of failure. 1498 Value parseOperand(EncodingReader &reader); 1499 1500 /// Sequentially define the given value range. 1501 LogicalResult defineValues(EncodingReader &reader, ValueRange values); 1502 1503 /// Create a value to use for a forward reference. 1504 Value createForwardRef(); 1505 1506 //===--------------------------------------------------------------------===// 1507 // Use-list order helpers 1508 1509 /// This struct is a simple storage that contains information required to 1510 /// reorder the use-list of a value with respect to the pre-order traversal 1511 /// ordering. 
1512 struct UseListOrderStorage { 1513 UseListOrderStorage(bool isIndexPairEncoding, 1514 SmallVector<unsigned, 4> &&indices) 1515 : indices(std::move(indices)), 1516 isIndexPairEncoding(isIndexPairEncoding){}; 1517 /// The vector containing the information required to reorder the 1518 /// use-list of a value. 1519 SmallVector<unsigned, 4> indices; 1520 1521 /// Whether indices represent a pair of type `(src, dst)` or it is a direct 1522 /// indexing, such as `dst = order[src]`. 1523 bool isIndexPairEncoding; 1524 }; 1525 1526 /// Parse use-list order from bytecode for a range of values if available. The 1527 /// range is expected to be either a block argument or an op result range. On 1528 /// success, return a map of the position in the range and the use-list order 1529 /// encoding. The function assumes to know the size of the range it is 1530 /// processing. 1531 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; 1532 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, 1533 uint64_t rangeSize); 1534 1535 /// Shuffle the use-chain according to the order parsed. 1536 LogicalResult sortUseListOrder(Value value); 1537 1538 /// Recursively visit all the values defined within topLevelOp and sort the 1539 /// use-list orders according to the indices parsed. 1540 LogicalResult processUseLists(Operation *topLevelOp); 1541 1542 //===--------------------------------------------------------------------===// 1543 // Fields 1544 1545 /// This class represents a single value scope, in which a value scope is 1546 /// delimited by isolated from above regions. 1547 struct ValueScope { 1548 /// Push a new region state onto this scope, reserving enough values for 1549 /// those defined within the current region of the provided state. 
1550 void push(RegionReadState &readState) { 1551 nextValueIDs.push_back(values.size()); 1552 values.resize(values.size() + readState.numValues); 1553 } 1554 1555 /// Pop the values defined for the current region within the provided region 1556 /// state. 1557 void pop(RegionReadState &readState) { 1558 values.resize(values.size() - readState.numValues); 1559 nextValueIDs.pop_back(); 1560 } 1561 1562 /// The set of values defined in this scope. 1563 std::vector<Value> values; 1564 1565 /// The ID for the next defined value for each region current being 1566 /// processed in this scope. 1567 SmallVector<unsigned, 4> nextValueIDs; 1568 }; 1569 1570 /// The configuration of the parser. 1571 const ParserConfig &config; 1572 1573 /// A location to use when emitting errors. 1574 Location fileLoc; 1575 1576 /// Flag that indicates if lazyloading is enabled. 1577 bool lazyLoading; 1578 1579 /// Keep track of operations that have been lazy loaded (their regions haven't 1580 /// been materialized), along with the `RegionReadState` that allows to 1581 /// lazy-load the regions nested under the operation. 1582 LazyLoadableOpsInfo lazyLoadableOps; 1583 LazyLoadableOpsMap lazyLoadableOpsMap; 1584 llvm::function_ref<bool(Operation *)> lazyOpsCallback; 1585 1586 /// The reader used to process attribute and types within the bytecode. 1587 AttrTypeReader attrTypeReader; 1588 1589 /// The version of the bytecode being read. 1590 uint64_t version = 0; 1591 1592 /// The producer of the bytecode being read. 1593 StringRef producer; 1594 1595 /// The table of IR units referenced within the bytecode file. 1596 SmallVector<std::unique_ptr<BytecodeDialect>> dialects; 1597 llvm::StringMap<BytecodeDialect *> dialectsMap; 1598 SmallVector<BytecodeOperationName> opNames; 1599 1600 /// The reader used to process resources within the bytecode. 1601 ResourceSectionReader resourceReader; 1602 1603 /// Worklist of values with custom use-list orders to process before the end 1604 /// of the parsing. 
DenseMap<void *, UseListOrderStorage> valueToUseListMap;

  /// The table of strings referenced within the bytecode file.
  StringSectionReader stringReader;

  /// The table of properties referenced by the operation in the bytecode file.
  PropertiesSectionReader propertiesReader;

  /// The current set of available IR value scopes.
  std::vector<ValueScope> valueScopes;

  /// The global pre-order operation ordering.
  DenseMap<Operation *, unsigned> operationIDs;

  /// A block containing the set of operations defined to create forward
  /// references.
  Block forwardRefOps;

  /// A block containing previously created, and no longer used, forward
  /// reference operations.
  Block openForwardRefOps;

  /// An operation state used when instantiating forward references.
  OperationState forwardRefOpState;

  /// Reference to the input buffer.
  llvm::MemoryBufferRef buffer;

  /// The optional owning source manager, which when present may be used to
  /// extend the lifetime of the input buffer.
  /// NOTE(review): this is a reference to a shared_ptr supplied to the
  /// constructor, so that shared_ptr object itself must outlive this Impl.
  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};

/// Read the whole bytecode buffer into `block`. This parses the header (magic,
/// version and producer), splits the remainder into top-level sections,
/// checks that every required section is present, and then processes the
/// sections in order: strings, properties, dialects, resources,
/// attributes/types, and finally the IR.
LogicalResult BytecodeReader::Impl::read(
    Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
  EncodingReader reader(buffer.getBuffer(), fileLoc);
  // Install the callback only for the duration of this call.
  this->lazyOpsCallback = lazyOpsCallback;
  auto resetlazyOpsCallback =
      llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });

  // Skip over the bytecode header, this should have already been checked.
  if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
    return failure();
  // Parse the bytecode version and producer.
  if (failed(parseVersion(reader)) ||
      failed(reader.parseNullTerminatedString(producer)))
    return failure();

  // Add a diagnostic handler that attaches a note that includes the original
  // producer of the bytecode. It returns failure() so the diagnostic is not
  // treated as handled here.
  ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
    diag.attachNote() << "in bytecode version " << version
                      << " produced by: " << producer;
    return failure();
  });

  // Parse the raw data for each of the top-level sections of the bytecode.
  std::optional<ArrayRef<uint8_t>>
      sectionDatas[bytecode::Section::kNumSections];
  while (!reader.empty()) {
    // Read the next section from the bytecode.
    bytecode::Section::ID sectionID;
    ArrayRef<uint8_t> sectionData;
    if (failed(reader.parseSection(sectionID, sectionData)))
      return failure();

    // Check for duplicate sections, we only expect one instance of each.
    // NOTE(review): indexing assumes parseSection rejects IDs >=
    // kNumSections; confirm in EncodingReader::parseSection.
    if (sectionDatas[sectionID]) {
      return reader.emitError("duplicate top-level section: ",
                              ::toString(sectionID));
    }
    sectionDatas[sectionID] = sectionData;
  }
  // Check that all of the required sections were found.
  for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
    bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
    if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
      return reader.emitError("missing data for top-level section: ",
                              ::toString(sectionID));
    }
  }

  // Process the string section first.
  if (failed(stringReader.initialize(
          fileLoc, *sectionDatas[bytecode::Section::kString])))
    return failure();

  // Process the properties section.
  if (sectionDatas[bytecode::Section::kProperties] &&
      failed(propertiesReader.initialize(
          fileLoc, *sectionDatas[bytecode::Section::kProperties])))
    return failure();

  // Process the dialect section.
  if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
    return failure();

  // Process the resource section if present.
  if (failed(parseResourceSection(
          reader, sectionDatas[bytecode::Section::kResource],
          sectionDatas[bytecode::Section::kResourceOffset])))
    return failure();

  // Process the attribute and type section.
  if (failed(attrTypeReader.initialize(
          dialects, *sectionDatas[bytecode::Section::kAttrType],
          *sectionDatas[bytecode::Section::kAttrTypeOffset])))
    return failure();

  // Finally, process the IR section.
  return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
}

/// Parse the bytecode version and check it is within the supported range.
/// Lazy loading is disabled for versions that predate it.
LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
  if (failed(reader.parseVarInt(version)))
    return failure();

  // Validate the bytecode version.
  uint64_t currentVersion = bytecode::kVersion;
  uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
  if (version < minSupportedVersion) {
    return reader.emitError("bytecode version ", version,
                            " is older than the current version of ",
                            currentVersion, ", and upgrade is not supported");
  }
  if (version > currentVersion) {
    return reader.emitError("bytecode version ", version,
                            " is newer than the current version ",
                            currentVersion);
  }
  // Override any request to lazy-load if the bytecode version is too old.
  if (version < bytecode::kLazyLoading)
    lazyLoading = false;
  return success();
}

//===----------------------------------------------------------------------===//
// Dialect Section

/// Load this dialect into `ctx` and, if a version entry was encoded, read it
/// through the dialect's bytecode interface. Returns immediately once a
/// dialect has been loaded. An unknown dialect is an error unless the context
/// allows unregistered dialects; in that case `dialect` stays null.
LogicalResult BytecodeDialect::load(const DialectReader &reader,
                                    MLIRContext *ctx) {
  if (dialect)
    return success();
  Dialect *loadedDialect = ctx->getOrLoadDialect(name);
  if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
    return reader.emitError("dialect '")
           << name
           << "' is unknown. If this is intended, please call "
              "allowUnregisteredDialects() on the MLIRContext, or use "
              "-allow-unregistered-dialect with the MLIR tool used.";
  }
  dialect = loadedDialect;

  // If the dialect was actually loaded, check to see if it has a bytecode
  // interface.
  if (loadedDialect)
    interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
  if (!versionBuffer.empty()) {
    if (!interface)
      return reader.emitError("dialect '")
             << name
             << "' does not implement the bytecode interface, "
                "but found a version entry";
    // Decode the version from its own buffer, sharing the caller's state.
    EncodingReader encReader(versionBuffer, reader.getLoc());
    DialectReader versionReader = reader.withEncodingReader(encReader);
    loadedVersion = interface->readVersion(versionReader);
    if (!loadedVersion)
      return failure();
  }
  return success();
}

/// Parse the dialect section: a varint dialect count followed by one entry
/// per dialect.
LogicalResult
BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
  EncodingReader sectionReader(sectionData, fileLoc);

  // Parse the number of dialects in the section.
  uint64_t numDialects;
  if (failed(sectionReader.parseVarInt(numDialects)))
    return failure();
  // NOTE(review): this resize is sized by an unvalidated count read from the
  // input; consider bounding it by sectionReader.size() as done elsewhere.
  dialects.resize(numDialects);

  // Parse each of the dialects.
  for (uint64_t i = 0; i < numDialects; ++i) {
    dialects[i] = std::make_unique<BytecodeDialect>();
    // Before version kDialectVersioning, there wasn't any versioning available
    // for dialects, and the entryIdx represents the string itself.
    if (version < bytecode::kDialectVersioning) {
      if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
        return failure();
      continue;
    }

    // Parse ID representing dialect and version.
1799 uint64_t dialectNameIdx; 1800 bool versionAvailable; 1801 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, 1802 versionAvailable))) 1803 return failure(); 1804 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, 1805 dialects[i]->name))) 1806 return failure(); 1807 if (versionAvailable) { 1808 bytecode::Section::ID sectionID; 1809 if (failed(sectionReader.parseSection(sectionID, 1810 dialects[i]->versionBuffer))) 1811 return failure(); 1812 if (sectionID != bytecode::Section::kDialectVersions) { 1813 emitError(fileLoc, "expected dialect version section"); 1814 return failure(); 1815 } 1816 } 1817 dialectsMap[dialects[i]->name] = dialects[i].get(); 1818 } 1819 1820 // Parse the operation names, which are grouped by dialect. 1821 auto parseOpName = [&](BytecodeDialect *dialect) { 1822 StringRef opName; 1823 std::optional<bool> wasRegistered; 1824 // Prior to version kNativePropertiesEncoding, the information about wheter 1825 // an op was registered or not wasn't encoded. 1826 if (version < bytecode::kNativePropertiesEncoding) { 1827 if (failed(stringReader.parseString(sectionReader, opName))) 1828 return failure(); 1829 } else { 1830 bool wasRegisteredFlag; 1831 if (failed(stringReader.parseStringWithFlag(sectionReader, opName, 1832 wasRegisteredFlag))) 1833 return failure(); 1834 wasRegistered = wasRegisteredFlag; 1835 } 1836 opNames.emplace_back(dialect, opName, wasRegistered); 1837 return success(); 1838 }; 1839 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation 1840 // where the number of ops are known. 
1841 if (version >= bytecode::kElideUnknownBlockArgLocation) { 1842 uint64_t numOps; 1843 if (failed(sectionReader.parseVarInt(numOps))) 1844 return failure(); 1845 opNames.reserve(numOps); 1846 } 1847 while (!sectionReader.empty()) 1848 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) 1849 return failure(); 1850 return success(); 1851 } 1852 1853 FailureOr<OperationName> 1854 BytecodeReader::Impl::parseOpName(EncodingReader &reader, 1855 std::optional<bool> &wasRegistered) { 1856 BytecodeOperationName *opName = nullptr; 1857 if (failed(parseEntry(reader, opNames, opName, "operation name"))) 1858 return failure(); 1859 wasRegistered = opName->wasRegistered; 1860 // Check to see if this operation name has already been resolved. If we 1861 // haven't, load the dialect and build the operation name. 1862 if (!opName->opName) { 1863 // Load the dialect and its version. 1864 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1865 dialectsMap, reader, version); 1866 if (failed(opName->dialect->load(dialectReader, getContext()))) 1867 return failure(); 1868 // If the opName is empty, this is because we use to accept names such as 1869 // `foo` without any `.` separator. We shouldn't tolerate this in textual 1870 // format anymore but for now we'll be backward compatible. This can only 1871 // happen with unregistered dialects. 1872 if (opName->name.empty()) { 1873 if (opName->dialect->getLoadedDialect()) 1874 return emitError(fileLoc) << "has an empty opname for dialect '" 1875 << opName->dialect->name << "'\n"; 1876 1877 opName->opName.emplace(opName->dialect->name, getContext()); 1878 } else { 1879 opName->opName.emplace((opName->dialect->name + "." 
+ opName->name).str(), 1880 getContext()); 1881 } 1882 } 1883 return *opName->opName; 1884 } 1885 1886 //===----------------------------------------------------------------------===// 1887 // Resource Section 1888 1889 LogicalResult BytecodeReader::Impl::parseResourceSection( 1890 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, 1891 std::optional<ArrayRef<uint8_t>> resourceOffsetData) { 1892 // Ensure both sections are either present or not. 1893 if (resourceData.has_value() != resourceOffsetData.has_value()) { 1894 if (resourceOffsetData) 1895 return emitError(fileLoc, "unexpected resource offset section when " 1896 "resource section is not present"); 1897 return emitError( 1898 fileLoc, 1899 "expected resource offset section when resource section is present"); 1900 } 1901 1902 // If the resource sections are absent, there is nothing to do. 1903 if (!resourceData) 1904 return success(); 1905 1906 // Initialize the resource reader with the resource sections. 1907 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 1908 dialectsMap, reader, version); 1909 return resourceReader.initialize(fileLoc, config, dialects, stringReader, 1910 *resourceData, *resourceOffsetData, 1911 dialectReader, bufferOwnerRef); 1912 } 1913 1914 //===----------------------------------------------------------------------===// 1915 // UseListOrder Helpers 1916 1917 FailureOr<BytecodeReader::Impl::UseListMapT> 1918 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, 1919 uint64_t numResults) { 1920 BytecodeReader::Impl::UseListMapT map; 1921 uint64_t numValuesToRead = 1; 1922 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) 1923 return failure(); 1924 1925 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { 1926 uint64_t resultIdx = 0; 1927 if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) 1928 return failure(); 1929 1930 uint64_t numValues; 1931 bool indexPairEncoding; 1932 if 
(failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) 1933 return failure(); 1934 1935 SmallVector<unsigned, 4> useListOrders; 1936 for (size_t idx = 0; idx < numValues; idx++) { 1937 uint64_t index; 1938 if (failed(reader.parseVarInt(index))) 1939 return failure(); 1940 useListOrders.push_back(index); 1941 } 1942 1943 // Store in a map the result index 1944 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, 1945 std::move(useListOrders))); 1946 } 1947 1948 return map; 1949 } 1950 1951 /// Sorts each use according to the order specified in the use-list parsed. If 1952 /// the custom use-list is not found, this means that the order needs to be 1953 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie 1954 /// on the same operation, the order will follow the reverse operand number 1955 /// ordering. 1956 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { 1957 // Early return for trivial use-lists. 1958 if (value.use_empty() || value.hasOneUse()) 1959 return success(); 1960 1961 bool hasIncomingOrder = 1962 valueToUseListMap.contains(value.getAsOpaquePointer()); 1963 1964 // Compute the current order of the use-list with respect to the global 1965 // ordering. Detect if the order is already sorted while doing so. 1966 bool alreadySorted = true; 1967 auto &firstUse = *value.use_begin(); 1968 uint64_t prevID = 1969 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); 1970 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; 1971 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { 1972 uint64_t currentID = bytecode::getUseID( 1973 item.value(), operationIDs.at(item.value().getOwner())); 1974 alreadySorted &= prevID > currentID; 1975 currentOrder.push_back({item.index(), currentID}); 1976 prevID = currentID; 1977 } 1978 1979 // If the order is already sorted, and there wasn't a custom order to apply 1980 // from the bytecode file, we are done. 
1981 if (alreadySorted && !hasIncomingOrder) 1982 return success(); 1983 1984 // If not already sorted, sort the indices of the current order by descending 1985 // useIDs. 1986 if (!alreadySorted) 1987 std::sort( 1988 currentOrder.begin(), currentOrder.end(), 1989 [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); 1990 1991 if (!hasIncomingOrder) { 1992 // If the bytecode file did not contain any custom use-list order, it means 1993 // that the order was descending useID. Hence, shuffle by the first index 1994 // of the `currentOrder` pair. 1995 SmallVector<unsigned> shuffle = SmallVector<unsigned>( 1996 llvm::map_range(currentOrder, [&](auto item) { return item.first; })); 1997 value.shuffleUseList(shuffle); 1998 return success(); 1999 } 2000 2001 // Pull the custom order info from the map. 2002 UseListOrderStorage customOrder = 2003 valueToUseListMap.at(value.getAsOpaquePointer()); 2004 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); 2005 uint64_t numUses = 2006 std::distance(value.getUses().begin(), value.getUses().end()); 2007 2008 // If the encoding was a pair of indices `(src, dst)` for every permutation, 2009 // reconstruct the shuffle vector for every use. Initialize the shuffle vector 2010 // as identity, and then apply the mapping encoded in the indices. 2011 if (customOrder.isIndexPairEncoding) { 2012 // Return failure if the number of indices was not representing pairs. 2013 if (shuffle.size() & 1) 2014 return failure(); 2015 2016 SmallVector<unsigned, 4> newShuffle(numUses); 2017 size_t idx = 0; 2018 std::iota(newShuffle.begin(), newShuffle.end(), idx); 2019 for (idx = 0; idx < shuffle.size(); idx += 2) 2020 newShuffle[shuffle[idx]] = shuffle[idx + 1]; 2021 2022 shuffle = std::move(newShuffle); 2023 } 2024 2025 // Make sure that the indices represent a valid mapping. That is, the sum of 2026 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no 2027 // duplicates are allowed in the list. 
2028 DenseSet<unsigned> set; 2029 uint64_t accumulator = 0; 2030 for (const auto &elem : shuffle) { 2031 if (set.contains(elem)) 2032 return failure(); 2033 accumulator += elem; 2034 set.insert(elem); 2035 } 2036 if (numUses != shuffle.size() || 2037 accumulator != (((numUses - 1) * numUses) >> 1)) 2038 return failure(); 2039 2040 // Apply the current ordering map onto the shuffle vector to get the final 2041 // use-list sorting indices before shuffling. 2042 shuffle = SmallVector<unsigned, 4>(llvm::map_range( 2043 currentOrder, [&](auto item) { return shuffle[item.first]; })); 2044 value.shuffleUseList(shuffle); 2045 return success(); 2046 } 2047 2048 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { 2049 // Precompute operation IDs according to the pre-order walk of the IR. We 2050 // can't do this while parsing since parseRegions ordering is not strictly 2051 // equal to the pre-order walk. 2052 unsigned operationID = 0; 2053 topLevelOp->walk<mlir::WalkOrder::PreOrder>( 2054 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); 2055 2056 auto blockWalk = topLevelOp->walk([this](Block *block) { 2057 for (auto arg : block->getArguments()) 2058 if (failed(sortUseListOrder(arg))) 2059 return WalkResult::interrupt(); 2060 return WalkResult::advance(); 2061 }); 2062 2063 auto resultWalk = topLevelOp->walk([this](Operation *op) { 2064 for (auto result : op->getResults()) 2065 if (failed(sortUseListOrder(result))) 2066 return WalkResult::interrupt(); 2067 return WalkResult::advance(); 2068 }); 2069 2070 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); 2071 } 2072 2073 //===----------------------------------------------------------------------===// 2074 // IR Section 2075 2076 LogicalResult 2077 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, 2078 Block *block) { 2079 EncodingReader reader(sectionData, fileLoc); 2080 2081 // A stack of operation regions currently being read from the 
bytecode. 2082 std::vector<RegionReadState> regionStack; 2083 2084 // Parse the top-level block using a temporary module operation. 2085 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); 2086 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); 2087 regionStack.back().curBlocks.push_back(moduleOp->getBody()); 2088 regionStack.back().curBlock = regionStack.back().curRegion->begin(); 2089 if (failed(parseBlockHeader(reader, regionStack.back()))) 2090 return failure(); 2091 valueScopes.emplace_back(); 2092 valueScopes.back().push(regionStack.back()); 2093 2094 // Iteratively parse regions until everything has been resolved. 2095 while (!regionStack.empty()) 2096 if (failed(parseRegions(regionStack, regionStack.back()))) 2097 return failure(); 2098 if (!forwardRefOps.empty()) { 2099 return reader.emitError( 2100 "not all forward unresolved forward operand references"); 2101 } 2102 2103 // Sort use-lists according to what specified in bytecode. 2104 if (failed(processUseLists(*moduleOp))) 2105 return reader.emitError( 2106 "parsed use-list orders were invalid and could not be applied"); 2107 2108 // Resolve dialect version. 2109 for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) { 2110 // Parsing is complete, give an opportunity to each dialect to visit the 2111 // IR and perform upgrades. 2112 if (!byteCodeDialect->loadedVersion) 2113 continue; 2114 if (byteCodeDialect->interface && 2115 failed(byteCodeDialect->interface->upgradeFromVersion( 2116 *moduleOp, *byteCodeDialect->loadedVersion))) 2117 return failure(); 2118 } 2119 2120 // Verify that the parsed operations are valid. 2121 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) 2122 return failure(); 2123 2124 // Splice the parsed operations over to the provided top-level block. 
2125 auto &parsedOps = moduleOp->getBody()->getOperations(); 2126 auto &destOps = block->getOperations(); 2127 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); 2128 return success(); 2129 } 2130 2131 LogicalResult 2132 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, 2133 RegionReadState &readState) { 2134 // Process regions, blocks, and operations until the end or if a nested 2135 // region is encountered. In this case we push a new state in regionStack and 2136 // return, the processing of the current region will resume afterward. 2137 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { 2138 // If the current block hasn't been setup yet, parse the header for this 2139 // region. The current block is already setup when this function was 2140 // interrupted to recurse down in a nested region and we resume the current 2141 // block after processing the nested region. 2142 if (readState.curBlock == Region::iterator()) { 2143 if (failed(parseRegion(readState))) 2144 return failure(); 2145 2146 // If the region is empty, there is nothing to more to do. 2147 if (readState.curRegion->empty()) 2148 continue; 2149 } 2150 2151 // Parse the blocks within the region. 2152 EncodingReader &reader = *readState.reader; 2153 do { 2154 while (readState.numOpsRemaining--) { 2155 // Read in the next operation. We don't read its regions directly, we 2156 // handle those afterwards as necessary. 2157 bool isIsolatedFromAbove = false; 2158 FailureOr<Operation *> op = 2159 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); 2160 if (failed(op)) 2161 return failure(); 2162 2163 // If the op has regions, add it to the stack for processing and return: 2164 // we stop the processing of the current region and resume it after the 2165 // inner one is completed. Unless LazyLoading is activated in which case 2166 // nested region parsing is delayed. 
2167 if ((*op)->getNumRegions()) { 2168 RegionReadState childState(*op, &reader, isIsolatedFromAbove); 2169 2170 // Isolated regions are encoded as a section in version 2 and above. 2171 if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { 2172 bytecode::Section::ID sectionID; 2173 ArrayRef<uint8_t> sectionData; 2174 if (failed(reader.parseSection(sectionID, sectionData))) 2175 return failure(); 2176 if (sectionID != bytecode::Section::kIR) 2177 return emitError(fileLoc, "expected IR section for region"); 2178 childState.owningReader = 2179 std::make_unique<EncodingReader>(sectionData, fileLoc); 2180 childState.reader = childState.owningReader.get(); 2181 2182 // If the user has a callback set, they have the opportunity to 2183 // control lazyloading as we go. 2184 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) { 2185 lazyLoadableOps.emplace_back(*op, std::move(childState)); 2186 lazyLoadableOpsMap.try_emplace(*op, 2187 std::prev(lazyLoadableOps.end())); 2188 continue; 2189 } 2190 } 2191 regionStack.push_back(std::move(childState)); 2192 2193 // If the op is isolated from above, push a new value scope. 2194 if (isIsolatedFromAbove) 2195 valueScopes.emplace_back(); 2196 return success(); 2197 } 2198 } 2199 2200 // Move to the next block of the region. 2201 if (++readState.curBlock == readState.curRegion->end()) 2202 break; 2203 if (failed(parseBlockHeader(reader, readState))) 2204 return failure(); 2205 } while (true); 2206 2207 // Reset the current block and any values reserved for this region. 2208 readState.curBlock = {}; 2209 valueScopes.back().pop(readState); 2210 } 2211 2212 // When the regions have been fully parsed, pop them off of the read stack. If 2213 // the regions were isolated from above, we also pop the last value scope. 
2214 if (readState.isIsolatedFromAbove) { 2215 assert(!valueScopes.empty() && "Expect a valueScope after reading region"); 2216 valueScopes.pop_back(); 2217 } 2218 assert(!regionStack.empty() && "Expect a regionStack after reading region"); 2219 regionStack.pop_back(); 2220 return success(); 2221 } 2222 2223 FailureOr<Operation *> 2224 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, 2225 RegionReadState &readState, 2226 bool &isIsolatedFromAbove) { 2227 // Parse the name of the operation. 2228 std::optional<bool> wasRegistered; 2229 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); 2230 if (failed(opName)) 2231 return failure(); 2232 2233 // Parse the operation mask, which indicates which components of the operation 2234 // are present. 2235 uint8_t opMask; 2236 if (failed(reader.parseByte(opMask))) 2237 return failure(); 2238 2239 /// Parse the location. 2240 LocationAttr opLoc; 2241 if (failed(parseAttribute(reader, opLoc))) 2242 return failure(); 2243 2244 // With the location and name resolved, we can start building the operation 2245 // state. 2246 OperationState opState(opLoc, *opName); 2247 2248 // Parse the attributes of the operation. 2249 if (opMask & bytecode::OpEncodingMask::kHasAttrs) { 2250 DictionaryAttr dictAttr; 2251 if (failed(parseAttribute(reader, dictAttr))) 2252 return failure(); 2253 opState.attributes = dictAttr; 2254 } 2255 2256 if (opMask & bytecode::OpEncodingMask::kHasProperties) { 2257 // kHasProperties wasn't emitted in older bytecode, we should never get 2258 // there without also having the `wasRegistered` flag available. 2259 if (!wasRegistered) 2260 return emitError(fileLoc, 2261 "Unexpected missing `wasRegistered` opname flag at " 2262 "bytecode version ") 2263 << version << " with properties."; 2264 // When an operation is emitted without being registered, the properties are 2265 // stored as an attribute. 
Otherwise the op must implement the bytecode 2266 // interface and control the serialization. 2267 if (wasRegistered) { 2268 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, 2269 dialectsMap, reader, version); 2270 if (failed( 2271 propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) 2272 return failure(); 2273 } else { 2274 // If the operation wasn't registered when it was emitted, the properties 2275 // was serialized as an attribute. 2276 if (failed(parseAttribute(reader, opState.propertiesAttr))) 2277 return failure(); 2278 } 2279 } 2280 2281 /// Parse the results of the operation. 2282 if (opMask & bytecode::OpEncodingMask::kHasResults) { 2283 uint64_t numResults; 2284 if (failed(reader.parseVarInt(numResults))) 2285 return failure(); 2286 opState.types.resize(numResults); 2287 for (int i = 0, e = numResults; i < e; ++i) 2288 if (failed(parseType(reader, opState.types[i]))) 2289 return failure(); 2290 } 2291 2292 /// Parse the operands of the operation. 2293 if (opMask & bytecode::OpEncodingMask::kHasOperands) { 2294 uint64_t numOperands; 2295 if (failed(reader.parseVarInt(numOperands))) 2296 return failure(); 2297 opState.operands.resize(numOperands); 2298 for (int i = 0, e = numOperands; i < e; ++i) 2299 if (!(opState.operands[i] = parseOperand(reader))) 2300 return failure(); 2301 } 2302 2303 /// Parse the successors of the operation. 2304 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { 2305 uint64_t numSuccs; 2306 if (failed(reader.parseVarInt(numSuccs))) 2307 return failure(); 2308 opState.successors.resize(numSuccs); 2309 for (int i = 0, e = numSuccs; i < e; ++i) { 2310 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], 2311 "successor"))) 2312 return failure(); 2313 } 2314 } 2315 2316 /// Parse the use-list orders for the results of the operation. Use-list 2317 /// orders are available since version 3 of the bytecode. 
2318 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; 2319 if (version >= bytecode::kUseListOrdering && 2320 (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { 2321 size_t numResults = opState.types.size(); 2322 auto parseResult = parseUseListOrderForRange(reader, numResults); 2323 if (failed(parseResult)) 2324 return failure(); 2325 resultIdxToUseListMap = std::move(*parseResult); 2326 } 2327 2328 /// Parse the regions of the operation. 2329 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { 2330 uint64_t numRegions; 2331 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) 2332 return failure(); 2333 2334 opState.regions.reserve(numRegions); 2335 for (int i = 0, e = numRegions; i < e; ++i) 2336 opState.regions.push_back(std::make_unique<Region>()); 2337 } 2338 2339 // Create the operation at the back of the current block. 2340 Operation *op = Operation::create(opState); 2341 readState.curBlock->push_back(op); 2342 2343 // If the operation had results, update the value references. 2344 if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) 2345 return failure(); 2346 2347 /// Store a map for every value that received a custom use-list order from the 2348 /// bytecode file. 2349 if (resultIdxToUseListMap.has_value()) { 2350 for (size_t idx = 0; idx < op->getNumResults(); idx++) { 2351 if (resultIdxToUseListMap->contains(idx)) { 2352 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), 2353 resultIdxToUseListMap->at(idx)); 2354 } 2355 } 2356 } 2357 return op; 2358 } 2359 2360 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { 2361 EncodingReader &reader = *readState.reader; 2362 2363 // Parse the number of blocks in the region. 2364 uint64_t numBlocks; 2365 if (failed(reader.parseVarInt(numBlocks))) 2366 return failure(); 2367 2368 // If the region is empty, there is nothing else to do. 
2369 if (numBlocks == 0) 2370 return success(); 2371 2372 // Parse the number of values defined in this region. 2373 uint64_t numValues; 2374 if (failed(reader.parseVarInt(numValues))) 2375 return failure(); 2376 readState.numValues = numValues; 2377 2378 // Create the blocks within this region. We do this before processing so that 2379 // we can rely on the blocks existing when creating operations. 2380 readState.curBlocks.clear(); 2381 readState.curBlocks.reserve(numBlocks); 2382 for (uint64_t i = 0; i < numBlocks; ++i) { 2383 readState.curBlocks.push_back(new Block()); 2384 readState.curRegion->push_back(readState.curBlocks.back()); 2385 } 2386 2387 // Prepare the current value scope for this region. 2388 valueScopes.back().push(readState); 2389 2390 // Parse the entry block of the region. 2391 readState.curBlock = readState.curRegion->begin(); 2392 return parseBlockHeader(reader, readState); 2393 } 2394 2395 LogicalResult 2396 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, 2397 RegionReadState &readState) { 2398 bool hasArgs; 2399 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) 2400 return failure(); 2401 2402 // Parse the arguments of the block. 2403 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) 2404 return failure(); 2405 2406 // Uselist orders are available since version 3 of the bytecode. 
2407 if (version < bytecode::kUseListOrdering) 2408 return success(); 2409 2410 uint8_t hasUseListOrders = 0; 2411 if (hasArgs && failed(reader.parseByte(hasUseListOrders))) 2412 return failure(); 2413 2414 if (!hasUseListOrders) 2415 return success(); 2416 2417 Block &blk = *readState.curBlock; 2418 auto argIdxToUseListMap = 2419 parseUseListOrderForRange(reader, blk.getNumArguments()); 2420 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) 2421 return failure(); 2422 2423 for (size_t idx = 0; idx < blk.getNumArguments(); idx++) 2424 if (argIdxToUseListMap->contains(idx)) 2425 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), 2426 argIdxToUseListMap->at(idx)); 2427 2428 // We don't parse the operations of the block here, that's done elsewhere. 2429 return success(); 2430 } 2431 2432 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, 2433 Block *block) { 2434 // Parse the value ID for the first argument, and the number of arguments. 2435 uint64_t numArgs; 2436 if (failed(reader.parseVarInt(numArgs))) 2437 return failure(); 2438 2439 SmallVector<Type> argTypes; 2440 SmallVector<Location> argLocs; 2441 argTypes.reserve(numArgs); 2442 argLocs.reserve(numArgs); 2443 2444 Location unknownLoc = UnknownLoc::get(config.getContext()); 2445 while (numArgs--) { 2446 Type argType; 2447 LocationAttr argLoc = unknownLoc; 2448 if (version >= bytecode::kElideUnknownBlockArgLocation) { 2449 // Parse the type with hasLoc flag to determine if it has type. 2450 uint64_t typeIdx; 2451 bool hasLoc; 2452 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || 2453 !(argType = attrTypeReader.resolveType(typeIdx))) 2454 return failure(); 2455 if (hasLoc && failed(parseAttribute(reader, argLoc))) 2456 return failure(); 2457 } else { 2458 // All args has type and location. 
2459 if (failed(parseType(reader, argType)) || 2460 failed(parseAttribute(reader, argLoc))) 2461 return failure(); 2462 } 2463 argTypes.push_back(argType); 2464 argLocs.push_back(argLoc); 2465 } 2466 block->addArguments(argTypes, argLocs); 2467 return defineValues(reader, block->getArguments()); 2468 } 2469 2470 //===----------------------------------------------------------------------===// 2471 // Value Processing 2472 2473 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { 2474 std::vector<Value> &values = valueScopes.back().values; 2475 Value *value = nullptr; 2476 if (failed(parseEntry(reader, values, value, "value"))) 2477 return Value(); 2478 2479 // Create a new forward reference if necessary. 2480 if (!*value) 2481 *value = createForwardRef(); 2482 return *value; 2483 } 2484 2485 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, 2486 ValueRange newValues) { 2487 ValueScope &valueScope = valueScopes.back(); 2488 std::vector<Value> &values = valueScope.values; 2489 2490 unsigned &valueID = valueScope.nextValueIDs.back(); 2491 unsigned valueIDEnd = valueID + newValues.size(); 2492 if (valueIDEnd > values.size()) { 2493 return reader.emitError( 2494 "value index range was outside of the expected range for " 2495 "the parent region, got [", 2496 valueID, ", ", valueIDEnd, "), but the maximum index was ", 2497 values.size() - 1); 2498 } 2499 2500 // Assign the values and update any forward references. 2501 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { 2502 Value newValue = newValues[i]; 2503 2504 // Check to see if a definition for this value already exists. 2505 if (Value oldValue = std::exchange(values[valueID], newValue)) { 2506 Operation *forwardRefOp = oldValue.getDefiningOp(); 2507 2508 // Assert that this is a forward reference operation. Given how we compute 2509 // definition ids (incrementally as we parse), it shouldn't be possible 2510 // for the value to be defined any other way. 
2511 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && 2512 "value index was already defined?"); 2513 2514 oldValue.replaceAllUsesWith(newValue); 2515 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); 2516 } 2517 } 2518 return success(); 2519 } 2520 2521 Value BytecodeReader::Impl::createForwardRef() { 2522 // Check for an avaliable existing operation to use. Otherwise, create a new 2523 // fake operation to use for the reference. 2524 if (!openForwardRefOps.empty()) { 2525 Operation *op = &openForwardRefOps.back(); 2526 op->moveBefore(&forwardRefOps, forwardRefOps.end()); 2527 } else { 2528 forwardRefOps.push_back(Operation::create(forwardRefOpState)); 2529 } 2530 return forwardRefOps.back().getResult(0); 2531 } 2532 2533 //===----------------------------------------------------------------------===// 2534 // Entry Points 2535 //===----------------------------------------------------------------------===// 2536 2537 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } 2538 2539 BytecodeReader::BytecodeReader( 2540 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, 2541 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2542 Location sourceFileLoc = 2543 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2544 /*line=*/0, /*column=*/0); 2545 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, 2546 bufferOwnerRef); 2547 } 2548 2549 LogicalResult BytecodeReader::readTopLevel( 2550 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2551 return impl->read(block, lazyOpsCallback); 2552 } 2553 2554 int64_t BytecodeReader::getNumOpsToMaterialize() const { 2555 return impl->getNumOpsToMaterialize(); 2556 } 2557 2558 bool BytecodeReader::isMaterializable(Operation *op) { 2559 return impl->isMaterializable(op); 2560 } 2561 2562 LogicalResult BytecodeReader::materialize( 2563 Operation *op, 
llvm::function_ref<bool(Operation *)> lazyOpsCallback) { 2564 return impl->materialize(op, lazyOpsCallback); 2565 } 2566 2567 LogicalResult 2568 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { 2569 return impl->finalize(shouldMaterialize); 2570 } 2571 2572 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { 2573 return buffer.getBuffer().startswith("ML\xefR"); 2574 } 2575 2576 /// Read the bytecode from the provided memory buffer reference. 2577 /// `bufferOwnerRef` if provided is the owning source manager for the buffer, 2578 /// and may be used to extend the lifetime of the buffer. 2579 static LogicalResult 2580 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, 2581 const ParserConfig &config, 2582 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { 2583 Location sourceFileLoc = 2584 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), 2585 /*line=*/0, /*column=*/0); 2586 if (!isBytecode(buffer)) { 2587 return emitError(sourceFileLoc, 2588 "input buffer is not an MLIR bytecode file"); 2589 } 2590 2591 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, 2592 buffer, bufferOwnerRef); 2593 return reader.read(block, /*lazyOpsCallback=*/nullptr); 2594 } 2595 2596 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, 2597 const ParserConfig &config) { 2598 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); 2599 } 2600 LogicalResult 2601 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, 2602 Block *block, const ParserConfig &config) { 2603 return readBytecodeFileImpl( 2604 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, 2605 sourceMgr); 2606 } 2607