xref: /llvm-project/mlir/lib/Bytecode/Reader/BytecodeReader.cpp (revision 1c8c365de2ac59198e6ef9d61a5557b846ce0bca)
1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Bytecode/BytecodeReader.h"
10 #include "mlir/AsmParser/AsmParser.h"
11 #include "mlir/Bytecode/BytecodeImplementation.h"
12 #include "mlir/Bytecode/BytecodeOpInterface.h"
13 #include "mlir/Bytecode/Encoding.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Support/LLVM.h"
20 #include "mlir/Support/LogicalResult.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Endian.h"
26 #include "llvm/Support/MemoryBufferRef.h"
27 #include "llvm/Support/SourceMgr.h"
28 
29 #include <cstddef>
30 #include <list>
31 #include <memory>
32 #include <numeric>
33 #include <optional>
34 
35 #define DEBUG_TYPE "mlir-bytecode-reader"
36 
37 using namespace mlir;
38 
39 /// Stringify the given section ID.
40 static std::string toString(bytecode::Section::ID sectionID) {
41   switch (sectionID) {
42   case bytecode::Section::kString:
43     return "String (0)";
44   case bytecode::Section::kDialect:
45     return "Dialect (1)";
46   case bytecode::Section::kAttrType:
47     return "AttrType (2)";
48   case bytecode::Section::kAttrTypeOffset:
49     return "AttrTypeOffset (3)";
50   case bytecode::Section::kIR:
51     return "IR (4)";
52   case bytecode::Section::kResource:
53     return "Resource (5)";
54   case bytecode::Section::kResourceOffset:
55     return "ResourceOffset (6)";
56   case bytecode::Section::kDialectVersions:
57     return "DialectVersions (7)";
58   case bytecode::Section::kProperties:
59     return "Properties (8)";
60   default:
61     return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
62   }
63 }
64 
65 /// Returns true if the given top-level section ID is optional.
66 static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
67   switch (sectionID) {
68   case bytecode::Section::kString:
69   case bytecode::Section::kDialect:
70   case bytecode::Section::kAttrType:
71   case bytecode::Section::kAttrTypeOffset:
72   case bytecode::Section::kIR:
73     return false;
74   case bytecode::Section::kResource:
75   case bytecode::Section::kResourceOffset:
76   case bytecode::Section::kDialectVersions:
77     return true;
78   case bytecode::Section::kProperties:
79     return version < bytecode::kNativePropertiesEncoding;
80   default:
81     llvm_unreachable("unknown section ID");
82   }
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // EncodingReader
87 //===----------------------------------------------------------------------===//
88 
89 namespace {
90 class EncodingReader {
91 public:
92   explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
93       : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
94   explicit EncodingReader(StringRef contents, Location fileLoc)
95       : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
96                         contents.size()},
97                        fileLoc) {}
98 
99   /// Returns true if the entire section has been read.
100   bool empty() const { return dataIt == buffer.end(); }
101 
102   /// Returns the remaining size of the bytecode.
103   size_t size() const { return buffer.end() - dataIt; }
104 
105   /// Align the current reader position to the specified alignment.
106   LogicalResult alignTo(unsigned alignment) {
107     if (!llvm::isPowerOf2_32(alignment))
108       return emitError("expected alignment to be a power-of-two");
109 
110     auto isUnaligned = [&](const uint8_t *ptr) {
111       return ((uintptr_t)ptr & (alignment - 1)) != 0;
112     };
113 
114     // Ensure the data buffer was sufficiently aligned in the first place.
115     if (LLVM_UNLIKELY(isUnaligned(buffer.begin()))) {
116       return emitError("expected bytecode buffer to be aligned to ", alignment,
117                        ", but got pointer: '0x" +
118                            llvm::utohexstr((uintptr_t)buffer.begin()) + "'");
119     }
120 
121     // Shift the reader position to the next alignment boundary.
122     while (isUnaligned(dataIt)) {
123       uint8_t padding;
124       if (failed(parseByte(padding)))
125         return failure();
126       if (padding != bytecode::kAlignmentByte) {
127         return emitError("expected alignment byte (0xCB), but got: '0x" +
128                          llvm::utohexstr(padding) + "'");
129       }
130     }
131 
132     // Ensure the data iterator is now aligned. This case is unlikely because we
133     // *just* went through the effort to align the data iterator.
134     if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
135       return emitError("expected data iterator aligned to ", alignment,
136                        ", but got pointer: '0x" +
137                            llvm::utohexstr((uintptr_t)dataIt) + "'");
138     }
139 
140     return success();
141   }
142 
143   /// Emit an error using the given arguments.
144   template <typename... Args>
145   InFlightDiagnostic emitError(Args &&...args) const {
146     return ::emitError(fileLoc).append(std::forward<Args>(args)...);
147   }
148   InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }
149 
150   /// Parse a single byte from the stream.
151   template <typename T>
152   LogicalResult parseByte(T &value) {
153     if (empty())
154       return emitError("attempting to parse a byte at the end of the bytecode");
155     value = static_cast<T>(*dataIt++);
156     return success();
157   }
158   /// Parse a range of bytes of 'length' into the given result.
159   LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
160     if (length > size()) {
161       return emitError("attempting to parse ", length, " bytes when only ",
162                        size(), " remain");
163     }
164     result = {dataIt, length};
165     dataIt += length;
166     return success();
167   }
168   /// Parse a range of bytes of 'length' into the given result, which can be
169   /// assumed to be large enough to hold `length`.
170   LogicalResult parseBytes(size_t length, uint8_t *result) {
171     if (length > size()) {
172       return emitError("attempting to parse ", length, " bytes when only ",
173                        size(), " remain");
174     }
175     memcpy(result, dataIt, length);
176     dataIt += length;
177     return success();
178   }
179 
180   /// Parse an aligned blob of data, where the alignment was encoded alongside
181   /// the data.
182   LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
183                                       uint64_t &alignment) {
184     uint64_t dataSize;
185     if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
186         failed(alignTo(alignment)))
187       return failure();
188     return parseBytes(dataSize, data);
189   }
190 
191   /// Parse a variable length encoded integer from the byte stream. The first
192   /// encoded byte contains a prefix in the low bits indicating the encoded
193   /// length of the value. This length prefix is a bit sequence of '0's followed
194   /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
195   /// (not including the prefix byte). All remaining bits in the first byte,
196   /// along with all of the bits in additional bytes, provide the value of the
197   /// integer encoded in little-endian order.
198   LogicalResult parseVarInt(uint64_t &result) {
199     // Parse the first byte of the encoding, which contains the length prefix.
200     if (failed(parseByte(result)))
201       return failure();
202 
203     // Handle the overwhelmingly common case where the value is stored in a
204     // single byte. In this case, the first bit is the `1` marker bit.
205     if (LLVM_LIKELY(result & 1)) {
206       result >>= 1;
207       return success();
208     }
209 
210     // Handle the overwhelming uncommon case where the value required all 8
211     // bytes (i.e. a really really big number). In this case, the marker byte is
212     // all zeros: `00000000`.
213     if (LLVM_UNLIKELY(result == 0)) {
214       llvm::support::ulittle64_t resultLE;
215       if (failed(parseBytes(sizeof(resultLE),
216                             reinterpret_cast<uint8_t *>(&resultLE))))
217         return failure();
218       result = resultLE;
219       return success();
220     }
221     return parseMultiByteVarInt(result);
222   }
223 
224   /// Parse a signed variable length encoded integer from the byte stream. A
225   /// signed varint is encoded as a normal varint with zigzag encoding applied,
226   /// i.e. the low bit of the value is used to indicate the sign.
227   LogicalResult parseSignedVarInt(uint64_t &result) {
228     if (failed(parseVarInt(result)))
229       return failure();
230     // Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
231     result = (result >> 1) ^ (~(result & 1) + 1);
232     return success();
233   }
234 
235   /// Parse a variable length encoded integer whose low bit is used to encode an
236   /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
237   LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
238     if (failed(parseVarInt(result)))
239       return failure();
240     flag = result & 1;
241     result >>= 1;
242     return success();
243   }
244 
245   /// Skip the first `length` bytes within the reader.
246   LogicalResult skipBytes(size_t length) {
247     if (length > size()) {
248       return emitError("attempting to skip ", length, " bytes when only ",
249                        size(), " remain");
250     }
251     dataIt += length;
252     return success();
253   }
254 
255   /// Parse a null-terminated string into `result` (without including the NUL
256   /// terminator).
257   LogicalResult parseNullTerminatedString(StringRef &result) {
258     const char *startIt = (const char *)dataIt;
259     const char *nulIt = (const char *)memchr(startIt, 0, size());
260     if (!nulIt)
261       return emitError(
262           "malformed null-terminated string, no null character found");
263 
264     result = StringRef(startIt, nulIt - startIt);
265     dataIt = (const uint8_t *)nulIt + 1;
266     return success();
267   }
268 
269   /// Parse a section header, placing the kind of section in `sectionID` and the
270   /// contents of the section in `sectionData`.
271   LogicalResult parseSection(bytecode::Section::ID &sectionID,
272                              ArrayRef<uint8_t> &sectionData) {
273     uint8_t sectionIDAndHasAlignment;
274     uint64_t length;
275     if (failed(parseByte(sectionIDAndHasAlignment)) ||
276         failed(parseVarInt(length)))
277       return failure();
278 
279     // Extract the section ID and whether the section is aligned. The high bit
280     // of the ID is the alignment flag.
281     sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
282                                                    0b01111111);
283     bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
284 
285     // Check that the section is actually valid before trying to process its
286     // data.
287     if (sectionID >= bytecode::Section::kNumSections)
288       return emitError("invalid section ID: ", unsigned(sectionID));
289 
290     // Process the section alignment if present.
291     if (hasAlignment) {
292       uint64_t alignment;
293       if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
294         return failure();
295     }
296 
297     // Parse the actual section data.
298     return parseBytes(static_cast<size_t>(length), sectionData);
299   }
300 
301   Location getLoc() const { return fileLoc; }
302 
303 private:
304   /// Parse a variable length encoded integer from the byte stream. This method
305   /// is a fallback when the number of bytes used to encode the value is greater
306   /// than 1, but less than the max (9). The provided `result` value can be
307   /// assumed to already contain the first byte of the value.
308   /// NOTE: This method is marked noinline to avoid pessimizing the common case
309   /// of single byte encoding.
310   LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
311     // Count the number of trailing zeros in the marker byte, this indicates the
312     // number of trailing bytes that are part of the value. We use `uint32_t`
313     // here because we only care about the first byte, and so that be actually
314     // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
315     // implementation).
316     uint32_t numBytes = llvm::countr_zero<uint32_t>(result);
317     assert(numBytes > 0 && numBytes <= 7 &&
318            "unexpected number of trailing zeros in varint encoding");
319 
320     // Parse in the remaining bytes of the value.
321     llvm::support::ulittle64_t resultLE(result);
322     if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1)))
323       return failure();
324 
325     // Shift out the low-order bits that were used to mark how the value was
326     // encoded.
327     result = resultLE >> (numBytes + 1);
328     return success();
329   }
330 
331   /// The bytecode buffer.
332   ArrayRef<uint8_t> buffer;
333 
334   /// The current iterator within the 'buffer'.
335   const uint8_t *dataIt;
336 
337   /// A location for the bytecode used to report errors.
338   Location fileLoc;
339 };
340 } // namespace
341 
342 /// Resolve an index into the given entry list. `entry` may either be a
343 /// reference, in which case it is assigned to the corresponding value in
344 /// `entries`, or a pointer, in which case it is assigned to the address of the
345 /// element in `entries`.
346 template <typename RangeT, typename T>
347 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
348                                   uint64_t index, T &entry,
349                                   StringRef entryStr) {
350   if (index >= entries.size())
351     return reader.emitError("invalid ", entryStr, " index: ", index);
352 
353   // If the provided entry is a pointer, resolve to the address of the entry.
354   if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
355     entry = entries[index];
356   else
357     entry = &entries[index];
358   return success();
359 }
360 
361 /// Parse and resolve an index into the given entry list.
362 template <typename RangeT, typename T>
363 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
364                                 T &entry, StringRef entryStr) {
365   uint64_t entryIdx;
366   if (failed(reader.parseVarInt(entryIdx)))
367     return failure();
368   return resolveEntry(reader, entries, entryIdx, entry, entryStr);
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // StringSectionReader
373 //===----------------------------------------------------------------------===//
374 
375 namespace {
376 /// This class is used to read references to the string section from the
377 /// bytecode.
378 class StringSectionReader {
379 public:
380   /// Initialize the string section reader with the given section data.
381   LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
382 
383   /// Parse a shared string from the string section. The shared string is
384   /// encoded using an index to a corresponding string in the string section.
385   LogicalResult parseString(EncodingReader &reader, StringRef &result) {
386     return parseEntry(reader, strings, result, "string");
387   }
388 
389   /// Parse a shared string from the string section. The shared string is
390   /// encoded using an index to a corresponding string in the string section.
391   /// This variant parses a flag compressed with the index.
392   LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
393                                     bool &flag) {
394     uint64_t entryIdx;
395     if (failed(reader.parseVarIntWithFlag(entryIdx, flag)))
396       return failure();
397     return parseStringAtIndex(reader, entryIdx, result);
398   }
399 
400   /// Parse a shared string from the string section. The shared string is
401   /// encoded using an index to a corresponding string in the string section.
402   LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
403                                    StringRef &result) {
404     return resolveEntry(reader, strings, index, result, "string");
405   }
406 
407 private:
408   /// The table of strings referenced within the bytecode file.
409   SmallVector<StringRef> strings;
410 };
411 } // namespace
412 
413 LogicalResult StringSectionReader::initialize(Location fileLoc,
414                                               ArrayRef<uint8_t> sectionData) {
415   EncodingReader stringReader(sectionData, fileLoc);
416 
417   // Parse the number of strings in the section.
418   uint64_t numStrings;
419   if (failed(stringReader.parseVarInt(numStrings)))
420     return failure();
421   strings.resize(numStrings);
422 
423   // Parse each of the strings. The sizes of the strings are encoded in reverse
424   // order, so that's the order we populate the table.
425   size_t stringDataEndOffset = sectionData.size();
426   for (StringRef &string : llvm::reverse(strings)) {
427     uint64_t stringSize;
428     if (failed(stringReader.parseVarInt(stringSize)))
429       return failure();
430     if (stringDataEndOffset < stringSize) {
431       return stringReader.emitError(
432           "string size exceeds the available data size");
433     }
434 
435     // Extract the string from the data, dropping the null character.
436     size_t stringOffset = stringDataEndOffset - stringSize;
437     string = StringRef(
438         reinterpret_cast<const char *>(sectionData.data() + stringOffset),
439         stringSize - 1);
440     stringDataEndOffset = stringOffset;
441   }
442 
443   // Check that the only remaining data was for the strings, i.e. the reader
444   // should be at the same offset as the first string.
445   if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
446     return stringReader.emitError("unexpected trailing data between the "
447                                   "offsets for strings and their data");
448   }
449   return success();
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // BytecodeDialect
454 //===----------------------------------------------------------------------===//
455 
456 namespace {
457 class DialectReader;
458 
459 /// This struct represents a dialect entry within the bytecode.
460 struct BytecodeDialect {
461   /// Load the dialect into the provided context if it hasn't been loaded yet.
462   /// Returns failure if the dialect couldn't be loaded *and* the provided
463   /// context does not allow unregistered dialects. The provided reader is used
464   /// for error emission if necessary.
465   LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
466 
467   /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
468   /// only be called after `load`.
469   Dialect *getLoadedDialect() const {
470     assert(dialect &&
471            "expected `load` to be invoked before `getLoadedDialect`");
472     return *dialect;
473   }
474 
475   /// The loaded dialect entry. This field is std::nullopt if we haven't
476   /// attempted to load, nullptr if we failed to load, otherwise the loaded
477   /// dialect.
478   std::optional<Dialect *> dialect;
479 
480   /// The bytecode interface of the dialect, or nullptr if the dialect does not
481   /// implement the bytecode interface. This field should only be checked if the
482   /// `dialect` field is not std::nullopt.
483   const BytecodeDialectInterface *interface = nullptr;
484 
485   /// The name of the dialect.
486   StringRef name;
487 
488   /// A buffer containing the encoding of the dialect version parsed.
489   ArrayRef<uint8_t> versionBuffer;
490 
491   /// Lazy loaded dialect version from the handle above.
492   std::unique_ptr<DialectVersion> loadedVersion;
493 };
494 
495 /// This struct represents an operation name entry within the bytecode.
496 struct BytecodeOperationName {
497   BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
498                         std::optional<bool> wasRegistered)
499       : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
500 
501   /// The loaded operation name, or std::nullopt if it hasn't been processed
502   /// yet.
503   std::optional<OperationName> opName;
504 
505   /// The dialect that owns this operation name.
506   BytecodeDialect *dialect;
507 
508   /// The name of the operation, without the dialect prefix.
509   StringRef name;
510 
511   /// Whether this operation was registered when the bytecode was produced.
512   /// This flag is populated when bytecode version >=kNativePropertiesEncoding.
513   std::optional<bool> wasRegistered;
514 };
515 } // namespace
516 
517 /// Parse a single dialect group encoded in the byte stream.
518 static LogicalResult parseDialectGrouping(
519     EncodingReader &reader,
520     MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
521     function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
522   // Parse the dialect and the number of entries in the group.
523   std::unique_ptr<BytecodeDialect> *dialect;
524   if (failed(parseEntry(reader, dialects, dialect, "dialect")))
525     return failure();
526   uint64_t numEntries;
527   if (failed(reader.parseVarInt(numEntries)))
528     return failure();
529 
530   for (uint64_t i = 0; i < numEntries; ++i)
531     if (failed(entryCallback(dialect->get())))
532       return failure();
533   return success();
534 }
535 
536 //===----------------------------------------------------------------------===//
537 // ResourceSectionReader
538 //===----------------------------------------------------------------------===//
539 
540 namespace {
541 /// This class is used to read the resource section from the bytecode.
542 class ResourceSectionReader {
543 public:
544   /// Initialize the resource section reader with the given section data.
545   LogicalResult
546   initialize(Location fileLoc, const ParserConfig &config,
547              MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
548              StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
549              ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
550              const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
551 
552   /// Parse a dialect resource handle from the resource section.
553   LogicalResult parseResourceHandle(EncodingReader &reader,
554                                     AsmDialectResourceHandle &result) {
555     return parseEntry(reader, dialectResources, result, "resource handle");
556   }
557 
558 private:
559   /// The table of dialect resources within the bytecode file.
560   SmallVector<AsmDialectResourceHandle> dialectResources;
561   llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
562 };
563 
564 class ParsedResourceEntry : public AsmParsedResourceEntry {
565 public:
566   ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
567                       EncodingReader &reader, StringSectionReader &stringReader,
568                       const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
569       : key(key), kind(kind), reader(reader), stringReader(stringReader),
570         bufferOwnerRef(bufferOwnerRef) {}
571   ~ParsedResourceEntry() override = default;
572 
573   StringRef getKey() const final { return key; }
574 
575   InFlightDiagnostic emitError() const final { return reader.emitError(); }
576 
577   AsmResourceEntryKind getKind() const final { return kind; }
578 
579   FailureOr<bool> parseAsBool() const final {
580     if (kind != AsmResourceEntryKind::Bool)
581       return emitError() << "expected a bool resource entry, but found a "
582                          << toString(kind) << " entry instead";
583 
584     bool value;
585     if (failed(reader.parseByte(value)))
586       return failure();
587     return value;
588   }
589   FailureOr<std::string> parseAsString() const final {
590     if (kind != AsmResourceEntryKind::String)
591       return emitError() << "expected a string resource entry, but found a "
592                          << toString(kind) << " entry instead";
593 
594     StringRef string;
595     if (failed(stringReader.parseString(reader, string)))
596       return failure();
597     return string.str();
598   }
599 
600   FailureOr<AsmResourceBlob>
601   parseAsBlob(BlobAllocatorFn allocator) const final {
602     if (kind != AsmResourceEntryKind::Blob)
603       return emitError() << "expected a blob resource entry, but found a "
604                          << toString(kind) << " entry instead";
605 
606     ArrayRef<uint8_t> data;
607     uint64_t alignment;
608     if (failed(reader.parseBlobAndAlignment(data, alignment)))
609       return failure();
610 
611     // If we have an extendable reference to the buffer owner, we don't need to
612     // allocate a new buffer for the data, and can use the data directly.
613     if (bufferOwnerRef) {
614       ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
615                               data.size());
616 
617       // Allocate an unmanager buffer which captures a reference to the owner.
618       // For now we just mark this as immutable, but in the future we should
619       // explore marking this as mutable when desired.
620       return UnmanagedAsmResourceBlob::allocateWithAlign(
621           charData, alignment,
622           [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
623     }
624 
625     // Allocate memory for the blob using the provided allocator and copy the
626     // data into it.
627     AsmResourceBlob blob = allocator(data.size(), alignment);
628     assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
629            blob.isMutable() &&
630            "blob allocator did not return a properly aligned address");
631     memcpy(blob.getMutableData().data(), data.data(), data.size());
632     return blob;
633   }
634 
635 private:
636   StringRef key;
637   AsmResourceEntryKind kind;
638   EncodingReader &reader;
639   StringSectionReader &stringReader;
640   const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
641 };
642 } // namespace
643 
644 template <typename T>
645 static LogicalResult
646 parseResourceGroup(Location fileLoc, bool allowEmpty,
647                    EncodingReader &offsetReader, EncodingReader &resourceReader,
648                    StringSectionReader &stringReader, T *handler,
649                    const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
650                    function_ref<StringRef(StringRef)> remapKey = {},
651                    function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
652   uint64_t numResources;
653   if (failed(offsetReader.parseVarInt(numResources)))
654     return failure();
655 
656   for (uint64_t i = 0; i < numResources; ++i) {
657     StringRef key;
658     AsmResourceEntryKind kind;
659     uint64_t resourceOffset;
660     ArrayRef<uint8_t> data;
661     if (failed(stringReader.parseString(offsetReader, key)) ||
662         failed(offsetReader.parseVarInt(resourceOffset)) ||
663         failed(offsetReader.parseByte(kind)) ||
664         failed(resourceReader.parseBytes(resourceOffset, data)))
665       return failure();
666 
667     // Process the resource key.
668     if ((processKeyFn && failed(processKeyFn(key))))
669       return failure();
670 
671     // If the resource data is empty and we allow it, don't error out when
672     // parsing below, just skip it.
673     if (allowEmpty && data.empty())
674       continue;
675 
676     // Ignore the entry if we don't have a valid handler.
677     if (!handler)
678       continue;
679 
680     // Otherwise, parse the resource value.
681     EncodingReader entryReader(data, fileLoc);
682     key = remapKey(key);
683     ParsedResourceEntry entry(key, kind, entryReader, stringReader,
684                               bufferOwnerRef);
685     if (failed(handler->parseResource(entry)))
686       return failure();
687     if (!entryReader.empty()) {
688       return entryReader.emitError(
689           "unexpected trailing bytes in resource entry '", key, "'");
690     }
691   }
692   return success();
693 }
694 
695 LogicalResult ResourceSectionReader::initialize(
696     Location fileLoc, const ParserConfig &config,
697     MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
698     StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
699     ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
700     const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
701   EncodingReader resourceReader(sectionData, fileLoc);
702   EncodingReader offsetReader(offsetSectionData, fileLoc);
703 
704   // Read the number of external resource providers.
705   uint64_t numExternalResourceGroups;
706   if (failed(offsetReader.parseVarInt(numExternalResourceGroups)))
707     return failure();
708 
709   // Utility functor that dispatches to `parseResourceGroup`, but implicitly
710   // provides most of the arguments.
711   auto parseGroup = [&](auto *handler, bool allowEmpty = false,
712                         function_ref<LogicalResult(StringRef)> keyFn = {}) {
713     auto resolveKey = [&](StringRef key) -> StringRef {
714       auto it = dialectResourceHandleRenamingMap.find(key);
715       if (it == dialectResourceHandleRenamingMap.end())
716         return "";
717       return it->second;
718     };
719 
720     return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
721                               stringReader, handler, bufferOwnerRef, resolveKey,
722                               keyFn);
723   };
724 
725   // Read the external resources from the bytecode.
726   for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
727     StringRef key;
728     if (failed(stringReader.parseString(offsetReader, key)))
729       return failure();
730 
731     // Get the handler for these resources.
732     // TODO: Should we require handling external resources in some scenarios?
733     AsmResourceParser *handler = config.getResourceParser(key);
734     if (!handler) {
735       emitWarning(fileLoc) << "ignoring unknown external resources for '" << key
736                            << "'";
737     }
738 
739     if (failed(parseGroup(handler)))
740       return failure();
741   }
742 
743   // Read the dialect resources from the bytecode.
744   MLIRContext *ctx = fileLoc->getContext();
745   while (!offsetReader.empty()) {
746     std::unique_ptr<BytecodeDialect> *dialect;
747     if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
748         failed((*dialect)->load(dialectReader, ctx)))
749       return failure();
750     Dialect *loadedDialect = (*dialect)->getLoadedDialect();
751     if (!loadedDialect) {
752       return resourceReader.emitError()
753              << "dialect '" << (*dialect)->name << "' is unknown";
754     }
755     const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
756     if (!handler) {
757       return resourceReader.emitError()
758              << "unexpected resources for dialect '" << (*dialect)->name << "'";
759     }
760 
761     // Ensure that each resource is declared before being processed.
762     auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
763       FailureOr<AsmDialectResourceHandle> handle =
764           handler->declareResource(key);
765       if (failed(handle)) {
766         return resourceReader.emitError()
767                << "unknown 'resource' key '" << key << "' for dialect '"
768                << (*dialect)->name << "'";
769       }
770       dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
771       dialectResources.push_back(*handle);
772       return success();
773     };
774 
775     // Parse the resources for this dialect. We allow empty resources because we
776     // just treat these as declarations.
777     if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
778       return failure();
779   }
780 
781   return success();
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // Attribute/Type Reader
786 //===----------------------------------------------------------------------===//
787 
788 namespace {
789 /// This class provides support for reading attribute and type entries from the
790 /// bytecode. Attribute and Type entries are read lazily on demand, so we use
791 /// this reader to manage when to actually parse them from the bytecode.
792 class AttrTypeReader {
793   /// This class represents a single attribute or type entry.
794   template <typename T>
795   struct Entry {
796     /// The entry, or null if it hasn't been resolved yet.
797     T entry = {};
798     /// The parent dialect of this entry.
799     BytecodeDialect *dialect = nullptr;
800     /// A flag indicating if the entry was encoded using a custom encoding,
801     /// instead of using the textual assembly format.
802     bool hasCustomEncoding = false;
803     /// The raw data of this entry in the bytecode.
804     ArrayRef<uint8_t> data;
805   };
806   using AttrEntry = Entry<Attribute>;
807   using TypeEntry = Entry<Type>;
808 
809 public:
810   AttrTypeReader(StringSectionReader &stringReader,
811                  ResourceSectionReader &resourceReader,
812                  const llvm::StringMap<BytecodeDialect *> &dialectsMap,
813                  uint64_t &bytecodeVersion, Location fileLoc,
814                  const ParserConfig &config)
815       : stringReader(stringReader), resourceReader(resourceReader),
816         dialectsMap(dialectsMap), fileLoc(fileLoc),
817         bytecodeVersion(bytecodeVersion), parserConfig(config) {}
818 
819   /// Initialize the attribute and type information within the reader.
820   LogicalResult
821   initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
822              ArrayRef<uint8_t> sectionData,
823              ArrayRef<uint8_t> offsetSectionData);
824 
825   /// Resolve the attribute or type at the given index. Returns nullptr on
826   /// failure.
827   Attribute resolveAttribute(size_t index) {
828     return resolveEntry(attributes, index, "Attribute");
829   }
830   Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
831 
832   /// Parse a reference to an attribute or type using the given reader.
833   LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
834     uint64_t attrIdx;
835     if (failed(reader.parseVarInt(attrIdx)))
836       return failure();
837     result = resolveAttribute(attrIdx);
838     return success(!!result);
839   }
840   LogicalResult parseOptionalAttribute(EncodingReader &reader,
841                                        Attribute &result) {
842     uint64_t attrIdx;
843     bool flag;
844     if (failed(reader.parseVarIntWithFlag(attrIdx, flag)))
845       return failure();
846     if (!flag)
847       return success();
848     result = resolveAttribute(attrIdx);
849     return success(!!result);
850   }
851 
852   LogicalResult parseType(EncodingReader &reader, Type &result) {
853     uint64_t typeIdx;
854     if (failed(reader.parseVarInt(typeIdx)))
855       return failure();
856     result = resolveType(typeIdx);
857     return success(!!result);
858   }
859 
860   template <typename T>
861   LogicalResult parseAttribute(EncodingReader &reader, T &result) {
862     Attribute baseResult;
863     if (failed(parseAttribute(reader, baseResult)))
864       return failure();
865     if ((result = dyn_cast<T>(baseResult)))
866       return success();
867     return reader.emitError("expected attribute of type: ",
868                             llvm::getTypeName<T>(), ", but got: ", baseResult);
869   }
870 
871 private:
872   /// Resolve the given entry at `index`.
873   template <typename T>
874   T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
875                  StringRef entryType);
876 
877   /// Parse an entry using the given reader that was encoded using the textual
878   /// assembly format.
879   template <typename T>
880   LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
881                               StringRef entryType);
882 
883   /// Parse an entry using the given reader that was encoded using a custom
884   /// bytecode format.
885   template <typename T>
886   LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
887                                  StringRef entryType);
888 
889   /// The string section reader used to resolve string references when parsing
890   /// custom encoded attribute/type entries.
891   StringSectionReader &stringReader;
892 
893   /// The resource section reader used to resolve resource references when
894   /// parsing custom encoded attribute/type entries.
895   ResourceSectionReader &resourceReader;
896 
897   /// The map of the loaded dialects used to retrieve dialect information, such
898   /// as the dialect version.
899   const llvm::StringMap<BytecodeDialect *> &dialectsMap;
900 
901   /// The set of attribute and type entries.
902   SmallVector<AttrEntry> attributes;
903   SmallVector<TypeEntry> types;
904 
905   /// A location used for error emission.
906   Location fileLoc;
907 
908   /// Current bytecode version being used.
909   uint64_t &bytecodeVersion;
910 
911   /// Reference to the parser configuration.
912   const ParserConfig &parserConfig;
913 };
914 
915 class DialectReader : public DialectBytecodeReader {
916 public:
917   DialectReader(AttrTypeReader &attrTypeReader,
918                 StringSectionReader &stringReader,
919                 ResourceSectionReader &resourceReader,
920                 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
921                 EncodingReader &reader, uint64_t &bytecodeVersion)
922       : attrTypeReader(attrTypeReader), stringReader(stringReader),
923         resourceReader(resourceReader), dialectsMap(dialectsMap),
924         reader(reader), bytecodeVersion(bytecodeVersion) {}
925 
926   InFlightDiagnostic emitError(const Twine &msg) const override {
927     return reader.emitError(msg);
928   }
929 
930   FailureOr<const DialectVersion *>
931   getDialectVersion(StringRef dialectName) const override {
932     // First check if the dialect is available in the map.
933     auto dialectEntry = dialectsMap.find(dialectName);
934     if (dialectEntry == dialectsMap.end())
935       return failure();
936     // If the dialect was found, try to load it. This will trigger reading the
937     // bytecode version from the version buffer if it wasn't already processed.
938     // Return failure if either of those two actions could not be completed.
939     if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
940         dialectEntry->getValue()->loadedVersion == nullptr)
941       return failure();
942     return dialectEntry->getValue()->loadedVersion.get();
943   }
944 
945   MLIRContext *getContext() const override { return getLoc().getContext(); }
946 
947   uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
948 
949   DialectReader withEncodingReader(EncodingReader &encReader) const {
950     return DialectReader(attrTypeReader, stringReader, resourceReader,
951                          dialectsMap, encReader, bytecodeVersion);
952   }
953 
954   Location getLoc() const { return reader.getLoc(); }
955 
956   //===--------------------------------------------------------------------===//
957   // IR
958   //===--------------------------------------------------------------------===//
959 
960   LogicalResult readAttribute(Attribute &result) override {
961     return attrTypeReader.parseAttribute(reader, result);
962   }
963   LogicalResult readOptionalAttribute(Attribute &result) override {
964     return attrTypeReader.parseOptionalAttribute(reader, result);
965   }
966   LogicalResult readType(Type &result) override {
967     return attrTypeReader.parseType(reader, result);
968   }
969 
970   FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
971     AsmDialectResourceHandle handle;
972     if (failed(resourceReader.parseResourceHandle(reader, handle)))
973       return failure();
974     return handle;
975   }
976 
977   //===--------------------------------------------------------------------===//
978   // Primitives
979   //===--------------------------------------------------------------------===//
980 
981   LogicalResult readVarInt(uint64_t &result) override {
982     return reader.parseVarInt(result);
983   }
984 
985   LogicalResult readSignedVarInt(int64_t &result) override {
986     uint64_t unsignedResult;
987     if (failed(reader.parseSignedVarInt(unsignedResult)))
988       return failure();
989     result = static_cast<int64_t>(unsignedResult);
990     return success();
991   }
992 
993   FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
994     // Small values are encoded using a single byte.
995     if (bitWidth <= 8) {
996       uint8_t value;
997       if (failed(reader.parseByte(value)))
998         return failure();
999       return APInt(bitWidth, value);
1000     }
1001 
1002     // Large values up to 64 bits are encoded using a single varint.
1003     if (bitWidth <= 64) {
1004       uint64_t value;
1005       if (failed(reader.parseSignedVarInt(value)))
1006         return failure();
1007       return APInt(bitWidth, value);
1008     }
1009 
1010     // Otherwise, for really big values we encode the array of active words in
1011     // the value.
1012     uint64_t numActiveWords;
1013     if (failed(reader.parseVarInt(numActiveWords)))
1014       return failure();
1015     SmallVector<uint64_t, 4> words(numActiveWords);
1016     for (uint64_t i = 0; i < numActiveWords; ++i)
1017       if (failed(reader.parseSignedVarInt(words[i])))
1018         return failure();
1019     return APInt(bitWidth, words);
1020   }
1021 
1022   FailureOr<APFloat>
1023   readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
1024     FailureOr<APInt> intVal =
1025         readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics));
1026     if (failed(intVal))
1027       return failure();
1028     return APFloat(semantics, *intVal);
1029   }
1030 
1031   LogicalResult readString(StringRef &result) override {
1032     return stringReader.parseString(reader, result);
1033   }
1034 
1035   LogicalResult readBlob(ArrayRef<char> &result) override {
1036     uint64_t dataSize;
1037     ArrayRef<uint8_t> data;
1038     if (failed(reader.parseVarInt(dataSize)) ||
1039         failed(reader.parseBytes(dataSize, data)))
1040       return failure();
1041     result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()),
1042                             data.size());
1043     return success();
1044   }
1045 
1046   LogicalResult readBool(bool &result) override {
1047     return reader.parseByte(result);
1048   }
1049 
1050 private:
1051   AttrTypeReader &attrTypeReader;
1052   StringSectionReader &stringReader;
1053   ResourceSectionReader &resourceReader;
1054   const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1055   EncodingReader &reader;
1056   uint64_t &bytecodeVersion;
1057 };
1058 
1059 /// Wraps the properties section and handles reading properties out of it.
1060 class PropertiesSectionReader {
1061 public:
1062   /// Initialize the properties section reader with the given section data.
1063   LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
1064     if (sectionData.empty())
1065       return success();
1066     EncodingReader propReader(sectionData, fileLoc);
1067     uint64_t count;
1068     if (failed(propReader.parseVarInt(count)))
1069       return failure();
1070     // Parse the raw properties buffer.
1071     if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
1072       return failure();
1073 
1074     EncodingReader offsetsReader(propertiesBuffers, fileLoc);
1075     offsetTable.reserve(count);
1076     for (auto idx : llvm::seq<int64_t>(0, count)) {
1077       (void)idx;
1078       offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
1079       ArrayRef<uint8_t> rawProperties;
1080       uint64_t dataSize;
1081       if (failed(offsetsReader.parseVarInt(dataSize)) ||
1082           failed(offsetsReader.parseBytes(dataSize, rawProperties)))
1083         return failure();
1084     }
1085     if (!offsetsReader.empty())
1086       return offsetsReader.emitError()
1087              << "Broken properties section: didn't exhaust the offsets table";
1088     return success();
1089   }
1090 
1091   LogicalResult read(Location fileLoc, DialectReader &dialectReader,
1092                      OperationName *opName, OperationState &opState) {
1093     uint64_t propertiesIdx;
1094     if (failed(dialectReader.readVarInt(propertiesIdx)))
1095       return failure();
1096     if (propertiesIdx >= offsetTable.size())
1097       return dialectReader.emitError("Properties idx out-of-bound for ")
1098              << opName->getStringRef();
1099     size_t propertiesOffset = offsetTable[propertiesIdx];
1100     if (propertiesIdx >= propertiesBuffers.size())
1101       return dialectReader.emitError("Properties offset out-of-bound for ")
1102              << opName->getStringRef();
1103 
1104     // Acquire the sub-buffer that represent the requested properties.
1105     ArrayRef<char> rawProperties;
1106     {
1107       // "Seek" to the requested offset by getting a new reader with the right
1108       // sub-buffer.
1109       EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
1110                             fileLoc);
1111       // Properties are stored as a sequence of {size + raw_data}.
1112       if (failed(
1113               dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
1114         return failure();
1115     }
1116     // Setup a new reader to read from the `rawProperties` sub-buffer.
1117     EncodingReader reader(
1118         StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
1119     DialectReader propReader = dialectReader.withEncodingReader(reader);
1120 
1121     auto *iface = opName->getInterface<BytecodeOpInterface>();
1122     if (iface)
1123       return iface->readProperties(propReader, opState);
1124     if (opName->isRegistered())
1125       return propReader.emitError(
1126                  "has properties but missing BytecodeOpInterface for ")
1127              << opName->getStringRef();
1128     // Unregistered op are storing properties as an attribute.
1129     return propReader.readAttribute(opState.propertiesAttr);
1130   }
1131 
1132 private:
1133   /// The properties buffer referenced within the bytecode file.
1134   ArrayRef<uint8_t> propertiesBuffers;
1135 
1136   /// Table of offset in the buffer above.
1137   SmallVector<int64_t> offsetTable;
1138 };
1139 } // namespace
1140 
1141 LogicalResult AttrTypeReader::initialize(
1142     MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
1143     ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
1144   EncodingReader offsetReader(offsetSectionData, fileLoc);
1145 
1146   // Parse the number of attribute and type entries.
1147   uint64_t numAttributes, numTypes;
1148   if (failed(offsetReader.parseVarInt(numAttributes)) ||
1149       failed(offsetReader.parseVarInt(numTypes)))
1150     return failure();
1151   attributes.resize(numAttributes);
1152   types.resize(numTypes);
1153 
1154   // A functor used to accumulate the offsets for the entries in the given
1155   // range.
1156   uint64_t currentOffset = 0;
1157   auto parseEntries = [&](auto &&range) {
1158     size_t currentIndex = 0, endIndex = range.size();
1159 
1160     // Parse an individual entry.
1161     auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
1162       auto &entry = range[currentIndex++];
1163 
1164       uint64_t entrySize;
1165       if (failed(offsetReader.parseVarIntWithFlag(entrySize,
1166                                                   entry.hasCustomEncoding)))
1167         return failure();
1168 
1169       // Verify that the offset is actually valid.
1170       if (currentOffset + entrySize > sectionData.size()) {
1171         return offsetReader.emitError(
1172             "Attribute or Type entry offset points past the end of section");
1173       }
1174 
1175       entry.data = sectionData.slice(currentOffset, entrySize);
1176       entry.dialect = dialect;
1177       currentOffset += entrySize;
1178       return success();
1179     };
1180     while (currentIndex != endIndex)
1181       if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
1182         return failure();
1183     return success();
1184   };
1185 
1186   // Process each of the attributes, and then the types.
1187   if (failed(parseEntries(attributes)) || failed(parseEntries(types)))
1188     return failure();
1189 
1190   // Ensure that we read everything from the section.
1191   if (!offsetReader.empty()) {
1192     return offsetReader.emitError(
1193         "unexpected trailing data in the Attribute/Type offset section");
1194   }
1195 
1196   return success();
1197 }
1198 
1199 template <typename T>
1200 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
1201                                StringRef entryType) {
1202   if (index >= entries.size()) {
1203     emitError(fileLoc) << "invalid " << entryType << " index: " << index;
1204     return {};
1205   }
1206 
1207   // If the entry has already been resolved, there is nothing left to do.
1208   Entry<T> &entry = entries[index];
1209   if (entry.entry)
1210     return entry.entry;
1211 
1212   // Parse the entry.
1213   EncodingReader reader(entry.data, fileLoc);
1214 
1215   // Parse based on how the entry was encoded.
1216   if (entry.hasCustomEncoding) {
1217     if (failed(parseCustomEntry(entry, reader, entryType)))
1218       return T();
1219   } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
1220     return T();
1221   }
1222 
1223   if (!reader.empty()) {
1224     reader.emitError("unexpected trailing bytes after " + entryType + " entry");
1225     return T();
1226   }
1227   return entry.entry;
1228 }
1229 
1230 template <typename T>
1231 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1232                                             StringRef entryType) {
1233   StringRef asmStr;
1234   if (failed(reader.parseNullTerminatedString(asmStr)))
1235     return failure();
1236 
1237   // Invoke the MLIR assembly parser to parse the entry text.
1238   size_t numRead = 0;
1239   MLIRContext *context = fileLoc->getContext();
1240   if constexpr (std::is_same_v<T, Type>)
1241     result =
1242         ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
1243   else
1244     result = ::parseAttribute(asmStr, context, Type(), &numRead,
1245                               /*isKnownNullTerminated=*/true);
1246   if (!result)
1247     return failure();
1248 
1249   // Ensure there weren't dangling characters after the entry.
1250   if (numRead != asmStr.size()) {
1251     return reader.emitError("trailing characters found after ", entryType,
1252                             " assembly format: ", asmStr.drop_front(numRead));
1253   }
1254   return success();
1255 }
1256 
1257 template <typename T>
1258 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1259                                                EncodingReader &reader,
1260                                                StringRef entryType) {
1261   DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
1262                               reader, bytecodeVersion);
1263   if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
1264     return failure();
1265 
1266   if constexpr (std::is_same_v<T, Type>) {
1267     // Try parsing with callbacks first if available.
1268     for (const auto &callback :
1269          parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
1270       if (failed(
1271               callback->read(dialectReader, entry.dialect->name, entry.entry)))
1272         return failure();
1273       // Early return if parsing was successful.
1274       if (!!entry.entry)
1275         return success();
1276 
1277       // Reset the reader if we failed to parse, so we can fall through the
1278       // other parsing functions.
1279       reader = EncodingReader(entry.data, reader.getLoc());
1280     }
1281   } else {
1282     // Try parsing with callbacks first if available.
1283     for (const auto &callback :
1284          parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
1285       if (failed(
1286               callback->read(dialectReader, entry.dialect->name, entry.entry)))
1287         return failure();
1288       // Early return if parsing was successful.
1289       if (!!entry.entry)
1290         return success();
1291 
1292       // Reset the reader if we failed to parse, so we can fall through the
1293       // other parsing functions.
1294       reader = EncodingReader(entry.data, reader.getLoc());
1295     }
1296   }
1297 
1298   // Ensure that the dialect implements the bytecode interface.
1299   if (!entry.dialect->interface) {
1300     return reader.emitError("dialect '", entry.dialect->name,
1301                             "' does not implement the bytecode interface");
1302   }
1303 
1304   if constexpr (std::is_same_v<T, Type>)
1305     entry.entry = entry.dialect->interface->readType(dialectReader);
1306   else
1307     entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1308 
1309   return success(!!entry.entry);
1310 }
1311 
1312 //===----------------------------------------------------------------------===//
1313 // Bytecode Reader
1314 //===----------------------------------------------------------------------===//
1315 
1316 /// This class is used to read a bytecode buffer and translate it into MLIR.
1317 class mlir::BytecodeReader::Impl {
1318   struct RegionReadState;
1319   using LazyLoadableOpsInfo =
1320       std::list<std::pair<Operation *, RegionReadState>>;
1321   using LazyLoadableOpsMap =
1322       DenseMap<Operation *, LazyLoadableOpsInfo::iterator>;
1323 
1324 public:
1325   Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
1326        llvm::MemoryBufferRef buffer,
1327        const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1328       : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1329         attrTypeReader(stringReader, resourceReader, dialectsMap, version,
1330                        fileLoc, config),
1331         // Use the builtin unrealized conversion cast operation to represent
1332         // forward references to values that aren't yet defined.
1333         forwardRefOpState(UnknownLoc::get(config.getContext()),
1334                           "builtin.unrealized_conversion_cast", ValueRange(),
1335                           NoneType::get(config.getContext())),
1336         buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1337 
1338   /// Read the bytecode defined within `buffer` into the given block.
1339   LogicalResult read(Block *block,
1340                      llvm::function_ref<bool(Operation *)> lazyOps);
1341 
1342   /// Return the number of ops that haven't been materialized yet.
1343   int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
1344 
1345   bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); }
1346 
1347   /// Materialize the provided operation, invoke the lazyOpsCallback on every
1348   /// newly found lazy operation.
1349   LogicalResult
1350   materialize(Operation *op,
1351               llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1352     this->lazyOpsCallback = lazyOpsCallback;
1353     auto resetlazyOpsCallback =
1354         llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1355     auto it = lazyLoadableOpsMap.find(op);
1356     assert(it != lazyLoadableOpsMap.end() &&
1357            "materialize called on non-materializable op");
1358     return materialize(it);
1359   }
1360 
1361   /// Materialize all operations.
1362   LogicalResult materializeAll() {
1363     while (!lazyLoadableOpsMap.empty()) {
1364       if (failed(materialize(lazyLoadableOpsMap.begin())))
1365         return failure();
1366     }
1367     return success();
1368   }
1369 
1370   /// Finalize the lazy-loading by calling back with every op that hasn't been
1371   /// materialized to let the client decide if the op should be deleted or
1372   /// materialized. The op is materialized if the callback returns true, deleted
1373   /// otherwise.
1374   LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
1375     while (!lazyLoadableOps.empty()) {
1376       Operation *op = lazyLoadableOps.begin()->first;
1377       if (shouldMaterialize(op)) {
1378         if (failed(materialize(lazyLoadableOpsMap.find(op))))
1379           return failure();
1380         continue;
1381       }
1382       op->dropAllReferences();
1383       op->erase();
1384       lazyLoadableOps.pop_front();
1385       lazyLoadableOpsMap.erase(op);
1386     }
1387     return success();
1388   }
1389 
1390 private:
1391   LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
1392     assert(it != lazyLoadableOpsMap.end() &&
1393            "materialize called on non-materializable op");
1394     valueScopes.emplace_back();
1395     std::vector<RegionReadState> regionStack;
1396     regionStack.push_back(std::move(it->getSecond()->second));
1397     lazyLoadableOps.erase(it->getSecond());
1398     lazyLoadableOpsMap.erase(it);
1399 
1400     while (!regionStack.empty())
1401       if (failed(parseRegions(regionStack, regionStack.back())))
1402         return failure();
1403     return success();
1404   }
1405 
1406   /// Return the context for this config.
1407   MLIRContext *getContext() const { return config.getContext(); }
1408 
1409   /// Parse the bytecode version.
1410   LogicalResult parseVersion(EncodingReader &reader);
1411 
1412   //===--------------------------------------------------------------------===//
1413   // Dialect Section
1414 
1415   LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1416 
1417   /// Parse an operation name reference using the given reader, and set the
1418   /// `wasRegistered` flag that indicates if the bytecode was produced by a
1419   /// context where opName was registered.
1420   FailureOr<OperationName> parseOpName(EncodingReader &reader,
1421                                        std::optional<bool> &wasRegistered);
1422 
1423   //===--------------------------------------------------------------------===//
1424   // Attribute/Type Section
1425 
1426   /// Parse an attribute or type using the given reader.
1427   template <typename T>
1428   LogicalResult parseAttribute(EncodingReader &reader, T &result) {
1429     return attrTypeReader.parseAttribute(reader, result);
1430   }
1431   LogicalResult parseType(EncodingReader &reader, Type &result) {
1432     return attrTypeReader.parseType(reader, result);
1433   }
1434 
1435   //===--------------------------------------------------------------------===//
1436   // Resource Section
1437 
1438   LogicalResult
1439   parseResourceSection(EncodingReader &reader,
1440                        std::optional<ArrayRef<uint8_t>> resourceData,
1441                        std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1442 
1443   //===--------------------------------------------------------------------===//
1444   // IR Section
1445 
1446   /// This struct represents the current read state of a range of regions. This
1447   /// struct is used to enable iterative parsing of regions.
1448   struct RegionReadState {
1449     RegionReadState(Operation *op, EncodingReader *reader,
1450                     bool isIsolatedFromAbove)
1451         : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
1452     RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1453                     bool isIsolatedFromAbove)
1454         : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1455           isIsolatedFromAbove(isIsolatedFromAbove) {}
1456 
1457     /// The current regions being read.
1458     MutableArrayRef<Region>::iterator curRegion, endRegion;
1459     /// This is the reader to use for this region, this pointer is pointing to
1460     /// the parent region reader unless the current region is IsolatedFromAbove,
1461     /// in which case the pointer is pointing to the `owningReader` which is a
1462     /// section dedicated to the current region.
1463     EncodingReader *reader;
1464     std::unique_ptr<EncodingReader> owningReader;
1465 
1466     /// The number of values defined immediately within this region.
1467     unsigned numValues = 0;
1468 
1469     /// The current blocks of the region being read.
1470     SmallVector<Block *> curBlocks;
1471     Region::iterator curBlock = {};
1472 
1473     /// The number of operations remaining to be read from the current block
1474     /// being read.
1475     uint64_t numOpsRemaining = 0;
1476 
1477     /// A flag indicating if the regions being read are isolated from above.
1478     bool isIsolatedFromAbove = false;
1479   };
1480 
1481   LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
1482   LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
1483                              RegionReadState &readState);
1484   FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1485                                                RegionReadState &readState,
1486                                                bool &isIsolatedFromAbove);
1487 
1488   LogicalResult parseRegion(RegionReadState &readState);
1489   LogicalResult parseBlockHeader(EncodingReader &reader,
1490                                  RegionReadState &readState);
1491   LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
1492 
1493   //===--------------------------------------------------------------------===//
1494   // Value Processing
1495 
1496   /// Parse an operand reference using the given reader. Returns nullptr in the
1497   /// case of failure.
1498   Value parseOperand(EncodingReader &reader);
1499 
1500   /// Sequentially define the given value range.
1501   LogicalResult defineValues(EncodingReader &reader, ValueRange values);
1502 
1503   /// Create a value to use for a forward reference.
1504   Value createForwardRef();
1505 
1506   //===--------------------------------------------------------------------===//
1507   // Use-list order helpers
1508 
1509   /// This struct is a simple storage that contains information required to
1510   /// reorder the use-list of a value with respect to the pre-order traversal
1511   /// ordering.
1512   struct UseListOrderStorage {
1513     UseListOrderStorage(bool isIndexPairEncoding,
1514                         SmallVector<unsigned, 4> &&indices)
1515         : indices(std::move(indices)),
1516           isIndexPairEncoding(isIndexPairEncoding){};
1517     /// The vector containing the information required to reorder the
1518     /// use-list of a value.
1519     SmallVector<unsigned, 4> indices;
1520 
1521     /// Whether indices represent a pair of type `(src, dst)` or it is a direct
1522     /// indexing, such as `dst = order[src]`.
1523     bool isIndexPairEncoding;
1524   };
1525 
1526   /// Parse use-list order from bytecode for a range of values if available. The
1527   /// range is expected to be either a block argument or an op result range. On
1528   /// success, return a map of the position in the range and the use-list order
1529   /// encoding. The function assumes to know the size of the range it is
1530   /// processing.
1531   using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
1532   FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1533                                                    uint64_t rangeSize);
1534 
1535   /// Shuffle the use-chain according to the order parsed.
1536   LogicalResult sortUseListOrder(Value value);
1537 
1538   /// Recursively visit all the values defined within topLevelOp and sort the
1539   /// use-list orders according to the indices parsed.
1540   LogicalResult processUseLists(Operation *topLevelOp);
1541 
1542   //===--------------------------------------------------------------------===//
1543   // Fields
1544 
1545   /// This class represents a single value scope, in which a value scope is
1546   /// delimited by isolated from above regions.
1547   struct ValueScope {
1548     /// Push a new region state onto this scope, reserving enough values for
1549     /// those defined within the current region of the provided state.
1550     void push(RegionReadState &readState) {
1551       nextValueIDs.push_back(values.size());
1552       values.resize(values.size() + readState.numValues);
1553     }
1554 
1555     /// Pop the values defined for the current region within the provided region
1556     /// state.
1557     void pop(RegionReadState &readState) {
1558       values.resize(values.size() - readState.numValues);
1559       nextValueIDs.pop_back();
1560     }
1561 
1562     /// The set of values defined in this scope.
1563     std::vector<Value> values;
1564 
1565     /// The ID for the next defined value for each region current being
1566     /// processed in this scope.
1567     SmallVector<unsigned, 4> nextValueIDs;
1568   };
1569 
1570   /// The configuration of the parser.
1571   const ParserConfig &config;
1572 
1573   /// A location to use when emitting errors.
1574   Location fileLoc;
1575 
1576   /// Flag that indicates if lazyloading is enabled.
1577   bool lazyLoading;
1578 
1579   /// Keep track of operations that have been lazy loaded (their regions haven't
1580   /// been materialized), along with the `RegionReadState` that allows to
1581   /// lazy-load the regions nested under the operation.
1582   LazyLoadableOpsInfo lazyLoadableOps;
1583   LazyLoadableOpsMap lazyLoadableOpsMap;
1584   llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1585 
1586   /// The reader used to process attribute and types within the bytecode.
1587   AttrTypeReader attrTypeReader;
1588 
1589   /// The version of the bytecode being read.
1590   uint64_t version = 0;
1591 
1592   /// The producer of the bytecode being read.
1593   StringRef producer;
1594 
1595   /// The table of IR units referenced within the bytecode file.
1596   SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
1597   llvm::StringMap<BytecodeDialect *> dialectsMap;
1598   SmallVector<BytecodeOperationName> opNames;
1599 
1600   /// The reader used to process resources within the bytecode.
1601   ResourceSectionReader resourceReader;
1602 
1603   /// Worklist of values with custom use-list orders to process before the end
1604   /// of the parsing.
1605   DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1606 
1607   /// The table of strings referenced within the bytecode file.
1608   StringSectionReader stringReader;
1609 
1610   /// The table of properties referenced by the operation in the bytecode file.
1611   PropertiesSectionReader propertiesReader;
1612 
1613   /// The current set of available IR value scopes.
1614   std::vector<ValueScope> valueScopes;
1615 
1616   /// The global pre-order operation ordering.
1617   DenseMap<Operation *, unsigned> operationIDs;
1618 
1619   /// A block containing the set of operations defined to create forward
1620   /// references.
1621   Block forwardRefOps;
1622 
1623   /// A block containing previously created, and no longer used, forward
1624   /// reference operations.
1625   Block openForwardRefOps;
1626 
1627   /// An operation state used when instantiating forward references.
1628   OperationState forwardRefOpState;
1629 
1630   /// Reference to the input buffer.
1631   llvm::MemoryBufferRef buffer;
1632 
1633   /// The optional owning source manager, which when present may be used to
1634   /// extend the lifetime of the input buffer.
1635   const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1636 };
1637 
1638 LogicalResult BytecodeReader::Impl::read(
1639     Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1640   EncodingReader reader(buffer.getBuffer(), fileLoc);
1641   this->lazyOpsCallback = lazyOpsCallback;
1642   auto resetlazyOpsCallback =
1643       llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1644 
1645   // Skip over the bytecode header, this should have already been checked.
1646   if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
1647     return failure();
1648   // Parse the bytecode version and producer.
1649   if (failed(parseVersion(reader)) ||
1650       failed(reader.parseNullTerminatedString(producer)))
1651     return failure();
1652 
1653   // Add a diagnostic handler that attaches a note that includes the original
1654   // producer of the bytecode.
1655   ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
1656     diag.attachNote() << "in bytecode version " << version
1657                       << " produced by: " << producer;
1658     return failure();
1659   });
1660 
1661   // Parse the raw data for each of the top-level sections of the bytecode.
1662   std::optional<ArrayRef<uint8_t>>
1663       sectionDatas[bytecode::Section::kNumSections];
1664   while (!reader.empty()) {
1665     // Read the next section from the bytecode.
1666     bytecode::Section::ID sectionID;
1667     ArrayRef<uint8_t> sectionData;
1668     if (failed(reader.parseSection(sectionID, sectionData)))
1669       return failure();
1670 
1671     // Check for duplicate sections, we only expect one instance of each.
1672     if (sectionDatas[sectionID]) {
1673       return reader.emitError("duplicate top-level section: ",
1674                               ::toString(sectionID));
1675     }
1676     sectionDatas[sectionID] = sectionData;
1677   }
1678   // Check that all of the required sections were found.
1679   for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
1680     bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
1681     if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
1682       return reader.emitError("missing data for top-level section: ",
1683                               ::toString(sectionID));
1684     }
1685   }
1686 
1687   // Process the string section first.
1688   if (failed(stringReader.initialize(
1689           fileLoc, *sectionDatas[bytecode::Section::kString])))
1690     return failure();
1691 
1692   // Process the properties section.
1693   if (sectionDatas[bytecode::Section::kProperties] &&
1694       failed(propertiesReader.initialize(
1695           fileLoc, *sectionDatas[bytecode::Section::kProperties])))
1696     return failure();
1697 
1698   // Process the dialect section.
1699   if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
1700     return failure();
1701 
1702   // Process the resource section if present.
1703   if (failed(parseResourceSection(
1704           reader, sectionDatas[bytecode::Section::kResource],
1705           sectionDatas[bytecode::Section::kResourceOffset])))
1706     return failure();
1707 
1708   // Process the attribute and type section.
1709   if (failed(attrTypeReader.initialize(
1710           dialects, *sectionDatas[bytecode::Section::kAttrType],
1711           *sectionDatas[bytecode::Section::kAttrTypeOffset])))
1712     return failure();
1713 
1714   // Finally, process the IR section.
1715   return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
1716 }
1717 
1718 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
1719   if (failed(reader.parseVarInt(version)))
1720     return failure();
1721 
1722   // Validate the bytecode version.
1723   uint64_t currentVersion = bytecode::kVersion;
1724   uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
1725   if (version < minSupportedVersion) {
1726     return reader.emitError("bytecode version ", version,
1727                             " is older than the current version of ",
1728                             currentVersion, ", and upgrade is not supported");
1729   }
1730   if (version > currentVersion) {
1731     return reader.emitError("bytecode version ", version,
1732                             " is newer than the current version ",
1733                             currentVersion);
1734   }
1735   // Override any request to lazy-load if the bytecode version is too old.
1736   if (version < bytecode::kLazyLoading)
1737     lazyLoading = false;
1738   return success();
1739 }
1740 
1741 //===----------------------------------------------------------------------===//
1742 // Dialect Section
1743 
1744 LogicalResult BytecodeDialect::load(const DialectReader &reader,
1745                                     MLIRContext *ctx) {
1746   if (dialect)
1747     return success();
1748   Dialect *loadedDialect = ctx->getOrLoadDialect(name);
1749   if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
1750     return reader.emitError("dialect '")
1751            << name
1752            << "' is unknown. If this is intended, please call "
1753               "allowUnregisteredDialects() on the MLIRContext, or use "
1754               "-allow-unregistered-dialect with the MLIR tool used.";
1755   }
1756   dialect = loadedDialect;
1757 
1758   // If the dialect was actually loaded, check to see if it has a bytecode
1759   // interface.
1760   if (loadedDialect)
1761     interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
1762   if (!versionBuffer.empty()) {
1763     if (!interface)
1764       return reader.emitError("dialect '")
1765              << name
1766              << "' does not implement the bytecode interface, "
1767                 "but found a version entry";
1768     EncodingReader encReader(versionBuffer, reader.getLoc());
1769     DialectReader versionReader = reader.withEncodingReader(encReader);
1770     loadedVersion = interface->readVersion(versionReader);
1771     if (!loadedVersion)
1772       return failure();
1773   }
1774   return success();
1775 }
1776 
1777 LogicalResult
1778 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
1779   EncodingReader sectionReader(sectionData, fileLoc);
1780 
1781   // Parse the number of dialects in the section.
1782   uint64_t numDialects;
1783   if (failed(sectionReader.parseVarInt(numDialects)))
1784     return failure();
1785   dialects.resize(numDialects);
1786 
1787   // Parse each of the dialects.
1788   for (uint64_t i = 0; i < numDialects; ++i) {
1789     dialects[i] = std::make_unique<BytecodeDialect>();
1790     /// Before version kDialectVersioning, there wasn't any versioning available
1791     /// for dialects, and the entryIdx represent the string itself.
1792     if (version < bytecode::kDialectVersioning) {
1793       if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
1794         return failure();
1795       continue;
1796     }
1797 
1798     // Parse ID representing dialect and version.
1799     uint64_t dialectNameIdx;
1800     bool versionAvailable;
1801     if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
1802                                                  versionAvailable)))
1803       return failure();
1804     if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
1805                                                dialects[i]->name)))
1806       return failure();
1807     if (versionAvailable) {
1808       bytecode::Section::ID sectionID;
1809       if (failed(sectionReader.parseSection(sectionID,
1810                                             dialects[i]->versionBuffer)))
1811         return failure();
1812       if (sectionID != bytecode::Section::kDialectVersions) {
1813         emitError(fileLoc, "expected dialect version section");
1814         return failure();
1815       }
1816     }
1817     dialectsMap[dialects[i]->name] = dialects[i].get();
1818   }
1819 
1820   // Parse the operation names, which are grouped by dialect.
1821   auto parseOpName = [&](BytecodeDialect *dialect) {
1822     StringRef opName;
1823     std::optional<bool> wasRegistered;
1824     // Prior to version kNativePropertiesEncoding, the information about wheter
1825     // an op was registered or not wasn't encoded.
1826     if (version < bytecode::kNativePropertiesEncoding) {
1827       if (failed(stringReader.parseString(sectionReader, opName)))
1828         return failure();
1829     } else {
1830       bool wasRegisteredFlag;
1831       if (failed(stringReader.parseStringWithFlag(sectionReader, opName,
1832                                                   wasRegisteredFlag)))
1833         return failure();
1834       wasRegistered = wasRegisteredFlag;
1835     }
1836     opNames.emplace_back(dialect, opName, wasRegistered);
1837     return success();
1838   };
1839   // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation
1840   // where the number of ops are known.
1841   if (version >= bytecode::kElideUnknownBlockArgLocation) {
1842     uint64_t numOps;
1843     if (failed(sectionReader.parseVarInt(numOps)))
1844       return failure();
1845     opNames.reserve(numOps);
1846   }
1847   while (!sectionReader.empty())
1848     if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName)))
1849       return failure();
1850   return success();
1851 }
1852 
1853 FailureOr<OperationName>
1854 BytecodeReader::Impl::parseOpName(EncodingReader &reader,
1855                                   std::optional<bool> &wasRegistered) {
1856   BytecodeOperationName *opName = nullptr;
1857   if (failed(parseEntry(reader, opNames, opName, "operation name")))
1858     return failure();
1859   wasRegistered = opName->wasRegistered;
1860   // Check to see if this operation name has already been resolved. If we
1861   // haven't, load the dialect and build the operation name.
1862   if (!opName->opName) {
1863     // Load the dialect and its version.
1864     DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1865                                 dialectsMap, reader, version);
1866     if (failed(opName->dialect->load(dialectReader, getContext())))
1867       return failure();
1868     // If the opName is empty, this is because we use to accept names such as
1869     // `foo` without any `.` separator. We shouldn't tolerate this in textual
1870     // format anymore but for now we'll be backward compatible. This can only
1871     // happen with unregistered dialects.
1872     if (opName->name.empty()) {
1873       if (opName->dialect->getLoadedDialect())
1874         return emitError(fileLoc) << "has an empty opname for dialect '"
1875                                   << opName->dialect->name << "'\n";
1876 
1877       opName->opName.emplace(opName->dialect->name, getContext());
1878     } else {
1879       opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
1880                              getContext());
1881     }
1882   }
1883   return *opName->opName;
1884 }
1885 
1886 //===----------------------------------------------------------------------===//
1887 // Resource Section
1888 
1889 LogicalResult BytecodeReader::Impl::parseResourceSection(
1890     EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
1891     std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
1892   // Ensure both sections are either present or not.
1893   if (resourceData.has_value() != resourceOffsetData.has_value()) {
1894     if (resourceOffsetData)
1895       return emitError(fileLoc, "unexpected resource offset section when "
1896                                 "resource section is not present");
1897     return emitError(
1898         fileLoc,
1899         "expected resource offset section when resource section is present");
1900   }
1901 
1902   // If the resource sections are absent, there is nothing to do.
1903   if (!resourceData)
1904     return success();
1905 
1906   // Initialize the resource reader with the resource sections.
1907   DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1908                               dialectsMap, reader, version);
1909   return resourceReader.initialize(fileLoc, config, dialects, stringReader,
1910                                    *resourceData, *resourceOffsetData,
1911                                    dialectReader, bufferOwnerRef);
1912 }
1913 
1914 //===----------------------------------------------------------------------===//
1915 // UseListOrder Helpers
1916 
1917 FailureOr<BytecodeReader::Impl::UseListMapT>
1918 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
1919                                                 uint64_t numResults) {
1920   BytecodeReader::Impl::UseListMapT map;
1921   uint64_t numValuesToRead = 1;
1922   if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead)))
1923     return failure();
1924 
1925   for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
1926     uint64_t resultIdx = 0;
1927     if (numResults > 1 && failed(reader.parseVarInt(resultIdx)))
1928       return failure();
1929 
1930     uint64_t numValues;
1931     bool indexPairEncoding;
1932     if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
1933       return failure();
1934 
1935     SmallVector<unsigned, 4> useListOrders;
1936     for (size_t idx = 0; idx < numValues; idx++) {
1937       uint64_t index;
1938       if (failed(reader.parseVarInt(index)))
1939         return failure();
1940       useListOrders.push_back(index);
1941     }
1942 
1943     // Store in a map the result index
1944     map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
1945                                                    std::move(useListOrders)));
1946   }
1947 
1948   return map;
1949 }
1950 
1951 /// Sorts each use according to the order specified in the use-list parsed. If
1952 /// the custom use-list is not found, this means that the order needs to be
1953 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie
1954 /// on the same operation, the order will follow the reverse operand number
1955 /// ordering.
1956 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
1957   // Early return for trivial use-lists.
1958   if (value.use_empty() || value.hasOneUse())
1959     return success();
1960 
1961   bool hasIncomingOrder =
1962       valueToUseListMap.contains(value.getAsOpaquePointer());
1963 
1964   // Compute the current order of the use-list with respect to the global
1965   // ordering. Detect if the order is already sorted while doing so.
1966   bool alreadySorted = true;
1967   auto &firstUse = *value.use_begin();
1968   uint64_t prevID =
1969       bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner()));
1970   llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
1971   for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
1972     uint64_t currentID = bytecode::getUseID(
1973         item.value(), operationIDs.at(item.value().getOwner()));
1974     alreadySorted &= prevID > currentID;
1975     currentOrder.push_back({item.index(), currentID});
1976     prevID = currentID;
1977   }
1978 
1979   // If the order is already sorted, and there wasn't a custom order to apply
1980   // from the bytecode file, we are done.
1981   if (alreadySorted && !hasIncomingOrder)
1982     return success();
1983 
1984   // If not already sorted, sort the indices of the current order by descending
1985   // useIDs.
1986   if (!alreadySorted)
1987     std::sort(
1988         currentOrder.begin(), currentOrder.end(),
1989         [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
1990 
1991   if (!hasIncomingOrder) {
1992     // If the bytecode file did not contain any custom use-list order, it means
1993     // that the order was descending useID. Hence, shuffle by the first index
1994     // of the `currentOrder` pair.
1995     SmallVector<unsigned> shuffle = SmallVector<unsigned>(
1996         llvm::map_range(currentOrder, [&](auto item) { return item.first; }));
1997     value.shuffleUseList(shuffle);
1998     return success();
1999   }
2000 
2001   // Pull the custom order info from the map.
2002   UseListOrderStorage customOrder =
2003       valueToUseListMap.at(value.getAsOpaquePointer());
2004   SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
2005   uint64_t numUses =
2006       std::distance(value.getUses().begin(), value.getUses().end());
2007 
2008   // If the encoding was a pair of indices `(src, dst)` for every permutation,
2009   // reconstruct the shuffle vector for every use. Initialize the shuffle vector
2010   // as identity, and then apply the mapping encoded in the indices.
2011   if (customOrder.isIndexPairEncoding) {
2012     // Return failure if the number of indices was not representing pairs.
2013     if (shuffle.size() & 1)
2014       return failure();
2015 
2016     SmallVector<unsigned, 4> newShuffle(numUses);
2017     size_t idx = 0;
2018     std::iota(newShuffle.begin(), newShuffle.end(), idx);
2019     for (idx = 0; idx < shuffle.size(); idx += 2)
2020       newShuffle[shuffle[idx]] = shuffle[idx + 1];
2021 
2022     shuffle = std::move(newShuffle);
2023   }
2024 
2025   // Make sure that the indices represent a valid mapping. That is, the sum of
2026   // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
2027   // duplicates are allowed in the list.
2028   DenseSet<unsigned> set;
2029   uint64_t accumulator = 0;
2030   for (const auto &elem : shuffle) {
2031     if (set.contains(elem))
2032       return failure();
2033     accumulator += elem;
2034     set.insert(elem);
2035   }
2036   if (numUses != shuffle.size() ||
2037       accumulator != (((numUses - 1) * numUses) >> 1))
2038     return failure();
2039 
2040   // Apply the current ordering map onto the shuffle vector to get the final
2041   // use-list sorting indices before shuffling.
2042   shuffle = SmallVector<unsigned, 4>(llvm::map_range(
2043       currentOrder, [&](auto item) { return shuffle[item.first]; }));
2044   value.shuffleUseList(shuffle);
2045   return success();
2046 }
2047 
2048 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
2049   // Precompute operation IDs according to the pre-order walk of the IR. We
2050   // can't do this while parsing since parseRegions ordering is not strictly
2051   // equal to the pre-order walk.
2052   unsigned operationID = 0;
2053   topLevelOp->walk<mlir::WalkOrder::PreOrder>(
2054       [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
2055 
2056   auto blockWalk = topLevelOp->walk([this](Block *block) {
2057     for (auto arg : block->getArguments())
2058       if (failed(sortUseListOrder(arg)))
2059         return WalkResult::interrupt();
2060     return WalkResult::advance();
2061   });
2062 
2063   auto resultWalk = topLevelOp->walk([this](Operation *op) {
2064     for (auto result : op->getResults())
2065       if (failed(sortUseListOrder(result)))
2066         return WalkResult::interrupt();
2067     return WalkResult::advance();
2068   });
2069 
2070   return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
2071 }
2072 
2073 //===----------------------------------------------------------------------===//
2074 // IR Section
2075 
2076 LogicalResult
2077 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2078                                      Block *block) {
2079   EncodingReader reader(sectionData, fileLoc);
2080 
2081   // A stack of operation regions currently being read from the bytecode.
2082   std::vector<RegionReadState> regionStack;
2083 
2084   // Parse the top-level block using a temporary module operation.
2085   OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
2086   regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
2087   regionStack.back().curBlocks.push_back(moduleOp->getBody());
2088   regionStack.back().curBlock = regionStack.back().curRegion->begin();
2089   if (failed(parseBlockHeader(reader, regionStack.back())))
2090     return failure();
2091   valueScopes.emplace_back();
2092   valueScopes.back().push(regionStack.back());
2093 
2094   // Iteratively parse regions until everything has been resolved.
2095   while (!regionStack.empty())
2096     if (failed(parseRegions(regionStack, regionStack.back())))
2097       return failure();
2098   if (!forwardRefOps.empty()) {
2099     return reader.emitError(
2100         "not all forward unresolved forward operand references");
2101   }
2102 
2103   // Sort use-lists according to what specified in bytecode.
2104   if (failed(processUseLists(*moduleOp)))
2105     return reader.emitError(
2106         "parsed use-list orders were invalid and could not be applied");
2107 
2108   // Resolve dialect version.
2109   for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
2110     // Parsing is complete, give an opportunity to each dialect to visit the
2111     // IR and perform upgrades.
2112     if (!byteCodeDialect->loadedVersion)
2113       continue;
2114     if (byteCodeDialect->interface &&
2115         failed(byteCodeDialect->interface->upgradeFromVersion(
2116             *moduleOp, *byteCodeDialect->loadedVersion)))
2117       return failure();
2118   }
2119 
2120   // Verify that the parsed operations are valid.
2121   if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
2122     return failure();
2123 
2124   // Splice the parsed operations over to the provided top-level block.
2125   auto &parsedOps = moduleOp->getBody()->getOperations();
2126   auto &destOps = block->getOperations();
2127   destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
2128   return success();
2129 }
2130 
2131 LogicalResult
2132 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
2133                                    RegionReadState &readState) {
2134   // Process regions, blocks, and operations until the end or if a nested
2135   // region is encountered. In this case we push a new state in regionStack and
2136   // return, the processing of the current region will resume afterward.
2137   for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
2138     // If the current block hasn't been setup yet, parse the header for this
2139     // region. The current block is already setup when this function was
2140     // interrupted to recurse down in a nested region and we resume the current
2141     // block after processing the nested region.
2142     if (readState.curBlock == Region::iterator()) {
2143       if (failed(parseRegion(readState)))
2144         return failure();
2145 
2146       // If the region is empty, there is nothing to more to do.
2147       if (readState.curRegion->empty())
2148         continue;
2149     }
2150 
2151     // Parse the blocks within the region.
2152     EncodingReader &reader = *readState.reader;
2153     do {
2154       while (readState.numOpsRemaining--) {
2155         // Read in the next operation. We don't read its regions directly, we
2156         // handle those afterwards as necessary.
2157         bool isIsolatedFromAbove = false;
2158         FailureOr<Operation *> op =
2159             parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
2160         if (failed(op))
2161           return failure();
2162 
2163         // If the op has regions, add it to the stack for processing and return:
2164         // we stop the processing of the current region and resume it after the
2165         // inner one is completed. Unless LazyLoading is activated in which case
2166         // nested region parsing is delayed.
2167         if ((*op)->getNumRegions()) {
2168           RegionReadState childState(*op, &reader, isIsolatedFromAbove);
2169 
2170           // Isolated regions are encoded as a section in version 2 and above.
2171           if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
2172             bytecode::Section::ID sectionID;
2173             ArrayRef<uint8_t> sectionData;
2174             if (failed(reader.parseSection(sectionID, sectionData)))
2175               return failure();
2176             if (sectionID != bytecode::Section::kIR)
2177               return emitError(fileLoc, "expected IR section for region");
2178             childState.owningReader =
2179                 std::make_unique<EncodingReader>(sectionData, fileLoc);
2180             childState.reader = childState.owningReader.get();
2181 
2182             // If the user has a callback set, they have the opportunity to
2183             // control lazyloading as we go.
2184             if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
2185               lazyLoadableOps.emplace_back(*op, std::move(childState));
2186               lazyLoadableOpsMap.try_emplace(*op,
2187                                              std::prev(lazyLoadableOps.end()));
2188               continue;
2189             }
2190           }
2191           regionStack.push_back(std::move(childState));
2192 
2193           // If the op is isolated from above, push a new value scope.
2194           if (isIsolatedFromAbove)
2195             valueScopes.emplace_back();
2196           return success();
2197         }
2198       }
2199 
2200       // Move to the next block of the region.
2201       if (++readState.curBlock == readState.curRegion->end())
2202         break;
2203       if (failed(parseBlockHeader(reader, readState)))
2204         return failure();
2205     } while (true);
2206 
2207     // Reset the current block and any values reserved for this region.
2208     readState.curBlock = {};
2209     valueScopes.back().pop(readState);
2210   }
2211 
2212   // When the regions have been fully parsed, pop them off of the read stack. If
2213   // the regions were isolated from above, we also pop the last value scope.
2214   if (readState.isIsolatedFromAbove) {
2215     assert(!valueScopes.empty() && "Expect a valueScope after reading region");
2216     valueScopes.pop_back();
2217   }
2218   assert(!regionStack.empty() && "Expect a regionStack after reading region");
2219   regionStack.pop_back();
2220   return success();
2221 }
2222 
2223 FailureOr<Operation *>
2224 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2225                                             RegionReadState &readState,
2226                                             bool &isIsolatedFromAbove) {
2227   // Parse the name of the operation.
2228   std::optional<bool> wasRegistered;
2229   FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
2230   if (failed(opName))
2231     return failure();
2232 
2233   // Parse the operation mask, which indicates which components of the operation
2234   // are present.
2235   uint8_t opMask;
2236   if (failed(reader.parseByte(opMask)))
2237     return failure();
2238 
2239   /// Parse the location.
2240   LocationAttr opLoc;
2241   if (failed(parseAttribute(reader, opLoc)))
2242     return failure();
2243 
2244   // With the location and name resolved, we can start building the operation
2245   // state.
2246   OperationState opState(opLoc, *opName);
2247 
2248   // Parse the attributes of the operation.
2249   if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
2250     DictionaryAttr dictAttr;
2251     if (failed(parseAttribute(reader, dictAttr)))
2252       return failure();
2253     opState.attributes = dictAttr;
2254   }
2255 
2256   if (opMask & bytecode::OpEncodingMask::kHasProperties) {
2257     // kHasProperties wasn't emitted in older bytecode, we should never get
2258     // there without also having the `wasRegistered` flag available.
2259     if (!wasRegistered)
2260       return emitError(fileLoc,
2261                        "Unexpected missing `wasRegistered` opname flag at "
2262                        "bytecode version ")
2263              << version << " with properties.";
2264     // When an operation is emitted without being registered, the properties are
2265     // stored as an attribute. Otherwise the op must implement the bytecode
2266     // interface and control the serialization.
2267     if (wasRegistered) {
2268       DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2269                                   dialectsMap, reader, version);
2270       if (failed(
2271               propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
2272         return failure();
2273     } else {
2274       // If the operation wasn't registered when it was emitted, the properties
2275       // was serialized as an attribute.
2276       if (failed(parseAttribute(reader, opState.propertiesAttr)))
2277         return failure();
2278     }
2279   }
2280 
2281   /// Parse the results of the operation.
2282   if (opMask & bytecode::OpEncodingMask::kHasResults) {
2283     uint64_t numResults;
2284     if (failed(reader.parseVarInt(numResults)))
2285       return failure();
2286     opState.types.resize(numResults);
2287     for (int i = 0, e = numResults; i < e; ++i)
2288       if (failed(parseType(reader, opState.types[i])))
2289         return failure();
2290   }
2291 
2292   /// Parse the operands of the operation.
2293   if (opMask & bytecode::OpEncodingMask::kHasOperands) {
2294     uint64_t numOperands;
2295     if (failed(reader.parseVarInt(numOperands)))
2296       return failure();
2297     opState.operands.resize(numOperands);
2298     for (int i = 0, e = numOperands; i < e; ++i)
2299       if (!(opState.operands[i] = parseOperand(reader)))
2300         return failure();
2301   }
2302 
2303   /// Parse the successors of the operation.
2304   if (opMask & bytecode::OpEncodingMask::kHasSuccessors) {
2305     uint64_t numSuccs;
2306     if (failed(reader.parseVarInt(numSuccs)))
2307       return failure();
2308     opState.successors.resize(numSuccs);
2309     for (int i = 0, e = numSuccs; i < e; ++i) {
2310       if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i],
2311                             "successor")))
2312         return failure();
2313     }
2314   }
2315 
2316   /// Parse the use-list orders for the results of the operation. Use-list
2317   /// orders are available since version 3 of the bytecode.
2318   std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2319   if (version >= bytecode::kUseListOrdering &&
2320       (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) {
2321     size_t numResults = opState.types.size();
2322     auto parseResult = parseUseListOrderForRange(reader, numResults);
2323     if (failed(parseResult))
2324       return failure();
2325     resultIdxToUseListMap = std::move(*parseResult);
2326   }
2327 
2328   /// Parse the regions of the operation.
2329   if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
2330     uint64_t numRegions;
2331     if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
2332       return failure();
2333 
2334     opState.regions.reserve(numRegions);
2335     for (int i = 0, e = numRegions; i < e; ++i)
2336       opState.regions.push_back(std::make_unique<Region>());
2337   }
2338 
2339   // Create the operation at the back of the current block.
2340   Operation *op = Operation::create(opState);
2341   readState.curBlock->push_back(op);
2342 
2343   // If the operation had results, update the value references.
2344   if (op->getNumResults() && failed(defineValues(reader, op->getResults())))
2345     return failure();
2346 
2347   /// Store a map for every value that received a custom use-list order from the
2348   /// bytecode file.
2349   if (resultIdxToUseListMap.has_value()) {
2350     for (size_t idx = 0; idx < op->getNumResults(); idx++) {
2351       if (resultIdxToUseListMap->contains(idx)) {
2352         valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(),
2353                                       resultIdxToUseListMap->at(idx));
2354       }
2355     }
2356   }
2357   return op;
2358 }
2359 
2360 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2361   EncodingReader &reader = *readState.reader;
2362 
2363   // Parse the number of blocks in the region.
2364   uint64_t numBlocks;
2365   if (failed(reader.parseVarInt(numBlocks)))
2366     return failure();
2367 
2368   // If the region is empty, there is nothing else to do.
2369   if (numBlocks == 0)
2370     return success();
2371 
2372   // Parse the number of values defined in this region.
2373   uint64_t numValues;
2374   if (failed(reader.parseVarInt(numValues)))
2375     return failure();
2376   readState.numValues = numValues;
2377 
2378   // Create the blocks within this region. We do this before processing so that
2379   // we can rely on the blocks existing when creating operations.
2380   readState.curBlocks.clear();
2381   readState.curBlocks.reserve(numBlocks);
2382   for (uint64_t i = 0; i < numBlocks; ++i) {
2383     readState.curBlocks.push_back(new Block());
2384     readState.curRegion->push_back(readState.curBlocks.back());
2385   }
2386 
2387   // Prepare the current value scope for this region.
2388   valueScopes.back().push(readState);
2389 
2390   // Parse the entry block of the region.
2391   readState.curBlock = readState.curRegion->begin();
2392   return parseBlockHeader(reader, readState);
2393 }
2394 
2395 LogicalResult
2396 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2397                                        RegionReadState &readState) {
2398   bool hasArgs;
2399   if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
2400     return failure();
2401 
2402   // Parse the arguments of the block.
2403   if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
2404     return failure();
2405 
2406   // Uselist orders are available since version 3 of the bytecode.
2407   if (version < bytecode::kUseListOrdering)
2408     return success();
2409 
2410   uint8_t hasUseListOrders = 0;
2411   if (hasArgs && failed(reader.parseByte(hasUseListOrders)))
2412     return failure();
2413 
2414   if (!hasUseListOrders)
2415     return success();
2416 
2417   Block &blk = *readState.curBlock;
2418   auto argIdxToUseListMap =
2419       parseUseListOrderForRange(reader, blk.getNumArguments());
2420   if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
2421     return failure();
2422 
2423   for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
2424     if (argIdxToUseListMap->contains(idx))
2425       valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(),
2426                                     argIdxToUseListMap->at(idx));
2427 
2428   // We don't parse the operations of the block here, that's done elsewhere.
2429   return success();
2430 }
2431 
2432 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
2433                                                         Block *block) {
2434   // Parse the value ID for the first argument, and the number of arguments.
2435   uint64_t numArgs;
2436   if (failed(reader.parseVarInt(numArgs)))
2437     return failure();
2438 
2439   SmallVector<Type> argTypes;
2440   SmallVector<Location> argLocs;
2441   argTypes.reserve(numArgs);
2442   argLocs.reserve(numArgs);
2443 
2444   Location unknownLoc = UnknownLoc::get(config.getContext());
2445   while (numArgs--) {
2446     Type argType;
2447     LocationAttr argLoc = unknownLoc;
2448     if (version >= bytecode::kElideUnknownBlockArgLocation) {
2449       // Parse the type with hasLoc flag to determine if it has type.
2450       uint64_t typeIdx;
2451       bool hasLoc;
2452       if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
2453           !(argType = attrTypeReader.resolveType(typeIdx)))
2454         return failure();
2455       if (hasLoc && failed(parseAttribute(reader, argLoc)))
2456         return failure();
2457     } else {
2458       // All args has type and location.
2459       if (failed(parseType(reader, argType)) ||
2460           failed(parseAttribute(reader, argLoc)))
2461         return failure();
2462     }
2463     argTypes.push_back(argType);
2464     argLocs.push_back(argLoc);
2465   }
2466   block->addArguments(argTypes, argLocs);
2467   return defineValues(reader, block->getArguments());
2468 }
2469 
2470 //===----------------------------------------------------------------------===//
2471 // Value Processing
2472 
2473 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2474   std::vector<Value> &values = valueScopes.back().values;
2475   Value *value = nullptr;
2476   if (failed(parseEntry(reader, values, value, "value")))
2477     return Value();
2478 
2479   // Create a new forward reference if necessary.
2480   if (!*value)
2481     *value = createForwardRef();
2482   return *value;
2483 }
2484 
2485 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2486                                                  ValueRange newValues) {
2487   ValueScope &valueScope = valueScopes.back();
2488   std::vector<Value> &values = valueScope.values;
2489 
2490   unsigned &valueID = valueScope.nextValueIDs.back();
2491   unsigned valueIDEnd = valueID + newValues.size();
2492   if (valueIDEnd > values.size()) {
2493     return reader.emitError(
2494         "value index range was outside of the expected range for "
2495         "the parent region, got [",
2496         valueID, ", ", valueIDEnd, "), but the maximum index was ",
2497         values.size() - 1);
2498   }
2499 
2500   // Assign the values and update any forward references.
2501   for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2502     Value newValue = newValues[i];
2503 
2504     // Check to see if a definition for this value already exists.
2505     if (Value oldValue = std::exchange(values[valueID], newValue)) {
2506       Operation *forwardRefOp = oldValue.getDefiningOp();
2507 
2508       // Assert that this is a forward reference operation. Given how we compute
2509       // definition ids (incrementally as we parse), it shouldn't be possible
2510       // for the value to be defined any other way.
2511       assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&
2512              "value index was already defined?");
2513 
2514       oldValue.replaceAllUsesWith(newValue);
2515       forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
2516     }
2517   }
2518   return success();
2519 }
2520 
2521 Value BytecodeReader::Impl::createForwardRef() {
2522   // Check for an avaliable existing operation to use. Otherwise, create a new
2523   // fake operation to use for the reference.
2524   if (!openForwardRefOps.empty()) {
2525     Operation *op = &openForwardRefOps.back();
2526     op->moveBefore(&forwardRefOps, forwardRefOps.end());
2527   } else {
2528     forwardRefOps.push_back(Operation::create(forwardRefOpState));
2529   }
2530   return forwardRefOps.back().getResult(0);
2531 }
2532 
2533 //===----------------------------------------------------------------------===//
2534 // Entry Points
2535 //===----------------------------------------------------------------------===//
2536 
2537 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); }
2538 
2539 BytecodeReader::BytecodeReader(
2540     llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
2541     const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2542   Location sourceFileLoc =
2543       FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2544                           /*line=*/0, /*column=*/0);
2545   impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
2546                                 bufferOwnerRef);
2547 }
2548 
2549 LogicalResult BytecodeReader::readTopLevel(
2550     Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2551   return impl->read(block, lazyOpsCallback);
2552 }
2553 
2554 int64_t BytecodeReader::getNumOpsToMaterialize() const {
2555   return impl->getNumOpsToMaterialize();
2556 }
2557 
2558 bool BytecodeReader::isMaterializable(Operation *op) {
2559   return impl->isMaterializable(op);
2560 }
2561 
2562 LogicalResult BytecodeReader::materialize(
2563     Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2564   return impl->materialize(op, lazyOpsCallback);
2565 }
2566 
2567 LogicalResult
2568 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) {
2569   return impl->finalize(shouldMaterialize);
2570 }
2571 
2572 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
2573   return buffer.getBuffer().startswith("ML\xefR");
2574 }
2575 
2576 /// Read the bytecode from the provided memory buffer reference.
2577 /// `bufferOwnerRef` if provided is the owning source manager for the buffer,
2578 /// and may be used to extend the lifetime of the buffer.
2579 static LogicalResult
2580 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
2581                      const ParserConfig &config,
2582                      const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2583   Location sourceFileLoc =
2584       FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2585                           /*line=*/0, /*column=*/0);
2586   if (!isBytecode(buffer)) {
2587     return emitError(sourceFileLoc,
2588                      "input buffer is not an MLIR bytecode file");
2589   }
2590 
2591   BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
2592                               buffer, bufferOwnerRef);
2593   return reader.read(block, /*lazyOpsCallback=*/nullptr);
2594 }
2595 
2596 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
2597                                      const ParserConfig &config) {
2598   return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{});
2599 }
2600 LogicalResult
2601 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
2602                        Block *block, const ParserConfig &config) {
2603   return readBytecodeFileImpl(
2604       *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config,
2605       sourceMgr);
2606 }
2607