xref: /llvm-project/mlir/lib/Bytecode/Reader/BytecodeReader.cpp (revision bb9a0c736b57f405c6fee598ce8043d0d35a5790)
1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // TODO: Support for big-endian architectures.
10 
11 #include "mlir/Bytecode/BytecodeReader.h"
12 #include "mlir/AsmParser/AsmParser.h"
13 #include "mlir/Bytecode/BytecodeImplementation.h"
14 #include "mlir/Bytecode/Encoding.h"
15 #include "mlir/IR/BuiltinDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/IR/Visitors.h"
21 #include "mlir/Support/LLVM.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "llvm/ADT/MapVector.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/MemoryBufferRef.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/SourceMgr.h"
31 #include <list>
32 #include <memory>
33 #include <numeric>
34 #include <optional>
35 
36 #define DEBUG_TYPE "mlir-bytecode-reader"
37 
38 using namespace mlir;
39 
40 /// Stringify the given section ID.
41 static std::string toString(bytecode::Section::ID sectionID) {
42   switch (sectionID) {
43   case bytecode::Section::kString:
44     return "String (0)";
45   case bytecode::Section::kDialect:
46     return "Dialect (1)";
47   case bytecode::Section::kAttrType:
48     return "AttrType (2)";
49   case bytecode::Section::kAttrTypeOffset:
50     return "AttrTypeOffset (3)";
51   case bytecode::Section::kIR:
52     return "IR (4)";
53   case bytecode::Section::kResource:
54     return "Resource (5)";
55   case bytecode::Section::kResourceOffset:
56     return "ResourceOffset (6)";
57   case bytecode::Section::kDialectVersions:
58     return "DialectVersions (7)";
59   default:
60     return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
61   }
62 }
63 
64 /// Returns true if the given top-level section ID is optional.
65 static bool isSectionOptional(bytecode::Section::ID sectionID) {
66   switch (sectionID) {
67   case bytecode::Section::kString:
68   case bytecode::Section::kDialect:
69   case bytecode::Section::kAttrType:
70   case bytecode::Section::kAttrTypeOffset:
71   case bytecode::Section::kIR:
72     return false;
73   case bytecode::Section::kResource:
74   case bytecode::Section::kResourceOffset:
75   case bytecode::Section::kDialectVersions:
76     return true;
77   default:
78     llvm_unreachable("unknown section ID");
79   }
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // EncodingReader
84 //===----------------------------------------------------------------------===//
85 
86 namespace {
87 class EncodingReader {
88 public:
89   explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
90       : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {}
91   explicit EncodingReader(StringRef contents, Location fileLoc)
92       : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
93                         contents.size()},
94                        fileLoc) {}
95 
96   /// Returns true if the entire section has been read.
97   bool empty() const { return dataIt == dataEnd; }
98 
99   /// Returns the remaining size of the bytecode.
100   size_t size() const { return dataEnd - dataIt; }
101 
102   /// Align the current reader position to the specified alignment.
103   LogicalResult alignTo(unsigned alignment) {
104     if (!llvm::isPowerOf2_32(alignment))
105       return emitError("expected alignment to be a power-of-two");
106 
107     // Shift the reader position to the next alignment boundary.
108     while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) {
109       uint8_t padding;
110       if (failed(parseByte(padding)))
111         return failure();
112       if (padding != bytecode::kAlignmentByte) {
113         return emitError("expected alignment byte (0xCB), but got: '0x" +
114                          llvm::utohexstr(padding) + "'");
115       }
116     }
117 
118     // Ensure the data iterator is now aligned. This case is unlikely because we
119     // *just* went through the effort to align the data iterator.
120     if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) {
121       return emitError("expected data iterator aligned to ", alignment,
122                        ", but got pointer: '0x" +
123                            llvm::utohexstr((uintptr_t)dataIt) + "'");
124     }
125 
126     return success();
127   }
128 
129   /// Emit an error using the given arguments.
130   template <typename... Args>
131   InFlightDiagnostic emitError(Args &&...args) const {
132     return ::emitError(fileLoc).append(std::forward<Args>(args)...);
133   }
134   InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }
135 
136   /// Parse a single byte from the stream.
137   template <typename T>
138   LogicalResult parseByte(T &value) {
139     if (empty())
140       return emitError("attempting to parse a byte at the end of the bytecode");
141     value = static_cast<T>(*dataIt++);
142     return success();
143   }
144   /// Parse a range of bytes of 'length' into the given result.
145   LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
146     if (length > size()) {
147       return emitError("attempting to parse ", length, " bytes when only ",
148                        size(), " remain");
149     }
150     result = {dataIt, length};
151     dataIt += length;
152     return success();
153   }
154   /// Parse a range of bytes of 'length' into the given result, which can be
155   /// assumed to be large enough to hold `length`.
156   LogicalResult parseBytes(size_t length, uint8_t *result) {
157     if (length > size()) {
158       return emitError("attempting to parse ", length, " bytes when only ",
159                        size(), " remain");
160     }
161     memcpy(result, dataIt, length);
162     dataIt += length;
163     return success();
164   }
165 
166   /// Parse an aligned blob of data, where the alignment was encoded alongside
167   /// the data.
168   LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
169                                       uint64_t &alignment) {
170     uint64_t dataSize;
171     if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
172         failed(alignTo(alignment)))
173       return failure();
174     return parseBytes(dataSize, data);
175   }
176 
177   /// Parse a variable length encoded integer from the byte stream. The first
178   /// encoded byte contains a prefix in the low bits indicating the encoded
179   /// length of the value. This length prefix is a bit sequence of '0's followed
180   /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
181   /// (not including the prefix byte). All remaining bits in the first byte,
182   /// along with all of the bits in additional bytes, provide the value of the
183   /// integer encoded in little-endian order.
184   LogicalResult parseVarInt(uint64_t &result) {
185     // Parse the first byte of the encoding, which contains the length prefix.
186     if (failed(parseByte(result)))
187       return failure();
188 
189     // Handle the overwhelmingly common case where the value is stored in a
190     // single byte. In this case, the first bit is the `1` marker bit.
191     if (LLVM_LIKELY(result & 1)) {
192       result >>= 1;
193       return success();
194     }
195 
196     // Handle the overwhelming uncommon case where the value required all 8
197     // bytes (i.e. a really really big number). In this case, the marker byte is
198     // all zeros: `00000000`.
199     if (LLVM_UNLIKELY(result == 0))
200       return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result));
201     return parseMultiByteVarInt(result);
202   }
203 
204   /// Parse a signed variable length encoded integer from the byte stream. A
205   /// signed varint is encoded as a normal varint with zigzag encoding applied,
206   /// i.e. the low bit of the value is used to indicate the sign.
207   LogicalResult parseSignedVarInt(uint64_t &result) {
208     if (failed(parseVarInt(result)))
209       return failure();
210     // Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
211     result = (result >> 1) ^ (~(result & 1) + 1);
212     return success();
213   }
214 
215   /// Parse a variable length encoded integer whose low bit is used to encode an
216   /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
217   LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
218     if (failed(parseVarInt(result)))
219       return failure();
220     flag = result & 1;
221     result >>= 1;
222     return success();
223   }
224 
225   /// Skip the first `length` bytes within the reader.
226   LogicalResult skipBytes(size_t length) {
227     if (length > size()) {
228       return emitError("attempting to skip ", length, " bytes when only ",
229                        size(), " remain");
230     }
231     dataIt += length;
232     return success();
233   }
234 
235   /// Parse a null-terminated string into `result` (without including the NUL
236   /// terminator).
237   LogicalResult parseNullTerminatedString(StringRef &result) {
238     const char *startIt = (const char *)dataIt;
239     const char *nulIt = (const char *)memchr(startIt, 0, size());
240     if (!nulIt)
241       return emitError(
242           "malformed null-terminated string, no null character found");
243 
244     result = StringRef(startIt, nulIt - startIt);
245     dataIt = (const uint8_t *)nulIt + 1;
246     return success();
247   }
248 
249   /// Parse a section header, placing the kind of section in `sectionID` and the
250   /// contents of the section in `sectionData`.
251   LogicalResult parseSection(bytecode::Section::ID &sectionID,
252                              ArrayRef<uint8_t> &sectionData) {
253     uint8_t sectionIDAndHasAlignment;
254     uint64_t length;
255     if (failed(parseByte(sectionIDAndHasAlignment)) ||
256         failed(parseVarInt(length)))
257       return failure();
258 
259     // Extract the section ID and whether the section is aligned. The high bit
260     // of the ID is the alignment flag.
261     sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
262                                                    0b01111111);
263     bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
264 
265     // Check that the section is actually valid before trying to process its
266     // data.
267     if (sectionID >= bytecode::Section::kNumSections)
268       return emitError("invalid section ID: ", unsigned(sectionID));
269 
270     // Process the section alignment if present.
271     if (hasAlignment) {
272       uint64_t alignment;
273       if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
274         return failure();
275     }
276 
277     // Parse the actual section data.
278     return parseBytes(static_cast<size_t>(length), sectionData);
279   }
280 
281   Location getLoc() const { return fileLoc; }
282 
283 private:
284   /// Parse a variable length encoded integer from the byte stream. This method
285   /// is a fallback when the number of bytes used to encode the value is greater
286   /// than 1, but less than the max (9). The provided `result` value can be
287   /// assumed to already contain the first byte of the value.
288   /// NOTE: This method is marked noinline to avoid pessimizing the common case
289   /// of single byte encoding.
290   LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
291     // Count the number of trailing zeros in the marker byte, this indicates the
292     // number of trailing bytes that are part of the value. We use `uint32_t`
293     // here because we only care about the first byte, and so that be actually
294     // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
295     // implementation).
296     uint32_t numBytes = llvm::countr_zero<uint32_t>(result);
297     assert(numBytes > 0 && numBytes <= 7 &&
298            "unexpected number of trailing zeros in varint encoding");
299 
300     // Parse in the remaining bytes of the value.
301     if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1)))
302       return failure();
303 
304     // Shift out the low-order bits that were used to mark how the value was
305     // encoded.
306     result >>= (numBytes + 1);
307     return success();
308   }
309 
310   /// The current data iterator, and an iterator to the end of the buffer.
311   const uint8_t *dataIt, *dataEnd;
312 
313   /// A location for the bytecode used to report errors.
314   Location fileLoc;
315 };
316 } // namespace
317 
318 /// Resolve an index into the given entry list. `entry` may either be a
319 /// reference, in which case it is assigned to the corresponding value in
320 /// `entries`, or a pointer, in which case it is assigned to the address of the
321 /// element in `entries`.
322 template <typename RangeT, typename T>
323 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
324                                   uint64_t index, T &entry,
325                                   StringRef entryStr) {
326   if (index >= entries.size())
327     return reader.emitError("invalid ", entryStr, " index: ", index);
328 
329   // If the provided entry is a pointer, resolve to the address of the entry.
330   if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
331     entry = entries[index];
332   else
333     entry = &entries[index];
334   return success();
335 }
336 
337 /// Parse and resolve an index into the given entry list.
338 template <typename RangeT, typename T>
339 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
340                                 T &entry, StringRef entryStr) {
341   uint64_t entryIdx;
342   if (failed(reader.parseVarInt(entryIdx)))
343     return failure();
344   return resolveEntry(reader, entries, entryIdx, entry, entryStr);
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // StringSectionReader
349 //===----------------------------------------------------------------------===//
350 
351 namespace {
352 /// This class is used to read references to the string section from the
353 /// bytecode.
354 class StringSectionReader {
355 public:
356   /// Initialize the string section reader with the given section data.
357   LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
358 
359   /// Parse a shared string from the string section. The shared string is
360   /// encoded using an index to a corresponding string in the string section.
361   LogicalResult parseString(EncodingReader &reader, StringRef &result) {
362     return parseEntry(reader, strings, result, "string");
363   }
364 
365   /// Parse a shared string from the string section. The shared string is
366   /// encoded using an index to a corresponding string in the string section.
367   LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
368                                    StringRef &result) {
369     return resolveEntry(reader, strings, index, result, "string");
370   }
371 
372 private:
373   /// The table of strings referenced within the bytecode file.
374   SmallVector<StringRef> strings;
375 };
376 } // namespace
377 
378 LogicalResult StringSectionReader::initialize(Location fileLoc,
379                                               ArrayRef<uint8_t> sectionData) {
380   EncodingReader stringReader(sectionData, fileLoc);
381 
382   // Parse the number of strings in the section.
383   uint64_t numStrings;
384   if (failed(stringReader.parseVarInt(numStrings)))
385     return failure();
386   strings.resize(numStrings);
387 
388   // Parse each of the strings. The sizes of the strings are encoded in reverse
389   // order, so that's the order we populate the table.
390   size_t stringDataEndOffset = sectionData.size();
391   for (StringRef &string : llvm::reverse(strings)) {
392     uint64_t stringSize;
393     if (failed(stringReader.parseVarInt(stringSize)))
394       return failure();
395     if (stringDataEndOffset < stringSize) {
396       return stringReader.emitError(
397           "string size exceeds the available data size");
398     }
399 
400     // Extract the string from the data, dropping the null character.
401     size_t stringOffset = stringDataEndOffset - stringSize;
402     string = StringRef(
403         reinterpret_cast<const char *>(sectionData.data() + stringOffset),
404         stringSize - 1);
405     stringDataEndOffset = stringOffset;
406   }
407 
408   // Check that the only remaining data was for the strings, i.e. the reader
409   // should be at the same offset as the first string.
410   if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
411     return stringReader.emitError("unexpected trailing data between the "
412                                   "offsets for strings and their data");
413   }
414   return success();
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // BytecodeDialect
419 //===----------------------------------------------------------------------===//
420 
421 namespace {
422 class DialectReader;
423 
424 /// This struct represents a dialect entry within the bytecode.
425 struct BytecodeDialect {
426   /// Load the dialect into the provided context if it hasn't been loaded yet.
427   /// Returns failure if the dialect couldn't be loaded *and* the provided
428   /// context does not allow unregistered dialects. The provided reader is used
429   /// for error emission if necessary.
430   LogicalResult load(DialectReader &reader, MLIRContext *ctx);
431 
432   /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
433   /// only be called after `load`.
434   Dialect *getLoadedDialect() const {
435     assert(dialect &&
436            "expected `load` to be invoked before `getLoadedDialect`");
437     return *dialect;
438   }
439 
440   /// The loaded dialect entry. This field is std::nullopt if we haven't
441   /// attempted to load, nullptr if we failed to load, otherwise the loaded
442   /// dialect.
443   std::optional<Dialect *> dialect;
444 
445   /// The bytecode interface of the dialect, or nullptr if the dialect does not
446   /// implement the bytecode interface. This field should only be checked if the
447   /// `dialect` field is not std::nullopt.
448   const BytecodeDialectInterface *interface = nullptr;
449 
450   /// The name of the dialect.
451   StringRef name;
452 
453   /// A buffer containing the encoding of the dialect version parsed.
454   ArrayRef<uint8_t> versionBuffer;
455 
456   /// Lazy loaded dialect version from the handle above.
457   std::unique_ptr<DialectVersion> loadedVersion;
458 };
459 
460 /// This struct represents an operation name entry within the bytecode.
461 struct BytecodeOperationName {
462   BytecodeOperationName(BytecodeDialect *dialect, StringRef name)
463       : dialect(dialect), name(name) {}
464 
465   /// The loaded operation name, or std::nullopt if it hasn't been processed
466   /// yet.
467   std::optional<OperationName> opName;
468 
469   /// The dialect that owns this operation name.
470   BytecodeDialect *dialect;
471 
472   /// The name of the operation, without the dialect prefix.
473   StringRef name;
474 };
475 } // namespace
476 
477 /// Parse a single dialect group encoded in the byte stream.
478 static LogicalResult parseDialectGrouping(
479     EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
480     function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
481   // Parse the dialect and the number of entries in the group.
482   BytecodeDialect *dialect;
483   if (failed(parseEntry(reader, dialects, dialect, "dialect")))
484     return failure();
485   uint64_t numEntries;
486   if (failed(reader.parseVarInt(numEntries)))
487     return failure();
488 
489   for (uint64_t i = 0; i < numEntries; ++i)
490     if (failed(entryCallback(dialect)))
491       return failure();
492   return success();
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // ResourceSectionReader
497 //===----------------------------------------------------------------------===//
498 
499 namespace {
500 /// This class is used to read the resource section from the bytecode.
501 class ResourceSectionReader {
502 public:
503   /// Initialize the resource section reader with the given section data.
504   LogicalResult
505   initialize(Location fileLoc, const ParserConfig &config,
506              MutableArrayRef<BytecodeDialect> dialects,
507              StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
508              ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
509              const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
510 
511   /// Parse a dialect resource handle from the resource section.
512   LogicalResult parseResourceHandle(EncodingReader &reader,
513                                     AsmDialectResourceHandle &result) {
514     return parseEntry(reader, dialectResources, result, "resource handle");
515   }
516 
517 private:
518   /// The table of dialect resources within the bytecode file.
519   SmallVector<AsmDialectResourceHandle> dialectResources;
520   llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
521 };
522 
523 class ParsedResourceEntry : public AsmParsedResourceEntry {
524 public:
525   ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
526                       EncodingReader &reader, StringSectionReader &stringReader,
527                       const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
528       : key(key), kind(kind), reader(reader), stringReader(stringReader),
529         bufferOwnerRef(bufferOwnerRef) {}
530   ~ParsedResourceEntry() override = default;
531 
532   StringRef getKey() const final { return key; }
533 
534   InFlightDiagnostic emitError() const final { return reader.emitError(); }
535 
536   AsmResourceEntryKind getKind() const final { return kind; }
537 
538   FailureOr<bool> parseAsBool() const final {
539     if (kind != AsmResourceEntryKind::Bool)
540       return emitError() << "expected a bool resource entry, but found a "
541                          << toString(kind) << " entry instead";
542 
543     bool value;
544     if (failed(reader.parseByte(value)))
545       return failure();
546     return value;
547   }
548   FailureOr<std::string> parseAsString() const final {
549     if (kind != AsmResourceEntryKind::String)
550       return emitError() << "expected a string resource entry, but found a "
551                          << toString(kind) << " entry instead";
552 
553     StringRef string;
554     if (failed(stringReader.parseString(reader, string)))
555       return failure();
556     return string.str();
557   }
558 
559   FailureOr<AsmResourceBlob>
560   parseAsBlob(BlobAllocatorFn allocator) const final {
561     if (kind != AsmResourceEntryKind::Blob)
562       return emitError() << "expected a blob resource entry, but found a "
563                          << toString(kind) << " entry instead";
564 
565     ArrayRef<uint8_t> data;
566     uint64_t alignment;
567     if (failed(reader.parseBlobAndAlignment(data, alignment)))
568       return failure();
569 
570     // If we have an extendable reference to the buffer owner, we don't need to
571     // allocate a new buffer for the data, and can use the data directly.
572     if (bufferOwnerRef) {
573       ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
574                               data.size());
575 
576       // Allocate an unmanager buffer which captures a reference to the owner.
577       // For now we just mark this as immutable, but in the future we should
578       // explore marking this as mutable when desired.
579       return UnmanagedAsmResourceBlob::allocateWithAlign(
580           charData, alignment,
581           [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
582     }
583 
584     // Allocate memory for the blob using the provided allocator and copy the
585     // data into it.
586     AsmResourceBlob blob = allocator(data.size(), alignment);
587     assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
588            blob.isMutable() &&
589            "blob allocator did not return a properly aligned address");
590     memcpy(blob.getMutableData().data(), data.data(), data.size());
591     return blob;
592   }
593 
594 private:
595   StringRef key;
596   AsmResourceEntryKind kind;
597   EncodingReader &reader;
598   StringSectionReader &stringReader;
599   const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
600 };
601 } // namespace
602 
603 template <typename T>
604 static LogicalResult
605 parseResourceGroup(Location fileLoc, bool allowEmpty,
606                    EncodingReader &offsetReader, EncodingReader &resourceReader,
607                    StringSectionReader &stringReader, T *handler,
608                    const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
609                    function_ref<StringRef(StringRef)> remapKey = {},
610                    function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
611   uint64_t numResources;
612   if (failed(offsetReader.parseVarInt(numResources)))
613     return failure();
614 
615   for (uint64_t i = 0; i < numResources; ++i) {
616     StringRef key;
617     AsmResourceEntryKind kind;
618     uint64_t resourceOffset;
619     ArrayRef<uint8_t> data;
620     if (failed(stringReader.parseString(offsetReader, key)) ||
621         failed(offsetReader.parseVarInt(resourceOffset)) ||
622         failed(offsetReader.parseByte(kind)) ||
623         failed(resourceReader.parseBytes(resourceOffset, data)))
624       return failure();
625 
626     // Process the resource key.
627     if ((processKeyFn && failed(processKeyFn(key))))
628       return failure();
629 
630     // If the resource data is empty and we allow it, don't error out when
631     // parsing below, just skip it.
632     if (allowEmpty && data.empty())
633       continue;
634 
635     // Ignore the entry if we don't have a valid handler.
636     if (!handler)
637       continue;
638 
639     // Otherwise, parse the resource value.
640     EncodingReader entryReader(data, fileLoc);
641     key = remapKey(key);
642     ParsedResourceEntry entry(key, kind, entryReader, stringReader,
643                               bufferOwnerRef);
644     if (failed(handler->parseResource(entry)))
645       return failure();
646     if (!entryReader.empty()) {
647       return entryReader.emitError(
648           "unexpected trailing bytes in resource entry '", key, "'");
649     }
650   }
651   return success();
652 }
653 
654 LogicalResult ResourceSectionReader::initialize(
655     Location fileLoc, const ParserConfig &config,
656     MutableArrayRef<BytecodeDialect> dialects,
657     StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
658     ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
659     const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
660   EncodingReader resourceReader(sectionData, fileLoc);
661   EncodingReader offsetReader(offsetSectionData, fileLoc);
662 
663   // Read the number of external resource providers.
664   uint64_t numExternalResourceGroups;
665   if (failed(offsetReader.parseVarInt(numExternalResourceGroups)))
666     return failure();
667 
668   // Utility functor that dispatches to `parseResourceGroup`, but implicitly
669   // provides most of the arguments.
670   auto parseGroup = [&](auto *handler, bool allowEmpty = false,
671                         function_ref<LogicalResult(StringRef)> keyFn = {}) {
672     auto resolveKey = [&](StringRef key) -> StringRef {
673       auto it = dialectResourceHandleRenamingMap.find(key);
674       if (it == dialectResourceHandleRenamingMap.end())
675         return "";
676       return it->second;
677     };
678 
679     return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
680                               stringReader, handler, bufferOwnerRef, resolveKey,
681                               keyFn);
682   };
683 
684   // Read the external resources from the bytecode.
685   for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
686     StringRef key;
687     if (failed(stringReader.parseString(offsetReader, key)))
688       return failure();
689 
690     // Get the handler for these resources.
691     // TODO: Should we require handling external resources in some scenarios?
692     AsmResourceParser *handler = config.getResourceParser(key);
693     if (!handler) {
694       emitWarning(fileLoc) << "ignoring unknown external resources for '" << key
695                            << "'";
696     }
697 
698     if (failed(parseGroup(handler)))
699       return failure();
700   }
701 
702   // Read the dialect resources from the bytecode.
703   MLIRContext *ctx = fileLoc->getContext();
704   while (!offsetReader.empty()) {
705     BytecodeDialect *dialect;
706     if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
707         failed(dialect->load(dialectReader, ctx)))
708       return failure();
709     Dialect *loadedDialect = dialect->getLoadedDialect();
710     if (!loadedDialect) {
711       return resourceReader.emitError()
712              << "dialect '" << dialect->name << "' is unknown";
713     }
714     const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
715     if (!handler) {
716       return resourceReader.emitError()
717              << "unexpected resources for dialect '" << dialect->name << "'";
718     }
719 
720     // Ensure that each resource is declared before being processed.
721     auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
722       FailureOr<AsmDialectResourceHandle> handle =
723           handler->declareResource(key);
724       if (failed(handle)) {
725         return resourceReader.emitError()
726                << "unknown 'resource' key '" << key << "' for dialect '"
727                << dialect->name << "'";
728       }
729       dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
730       dialectResources.push_back(*handle);
731       return success();
732     };
733 
734     // Parse the resources for this dialect. We allow empty resources because we
735     // just treat these as declarations.
736     if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
737       return failure();
738   }
739 
740   return success();
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // Attribute/Type Reader
745 //===----------------------------------------------------------------------===//
746 
747 namespace {
748 /// This class provides support for reading attribute and type entries from the
749 /// bytecode. Attribute and Type entries are read lazily on demand, so we use
750 /// this reader to manage when to actually parse them from the bytecode.
751 class AttrTypeReader {
752   /// This class represents a single attribute or type entry.
753   template <typename T>
754   struct Entry {
755     /// The entry, or null if it hasn't been resolved yet.
756     T entry = {};
757     /// The parent dialect of this entry.
758     BytecodeDialect *dialect = nullptr;
759     /// A flag indicating if the entry was encoded using a custom encoding,
760     /// instead of using the textual assembly format.
761     bool hasCustomEncoding = false;
762     /// The raw data of this entry in the bytecode.
763     ArrayRef<uint8_t> data;
764   };
765   using AttrEntry = Entry<Attribute>;
766   using TypeEntry = Entry<Type>;
767 
768 public:
769   AttrTypeReader(StringSectionReader &stringReader,
770                  ResourceSectionReader &resourceReader, Location fileLoc)
771       : stringReader(stringReader), resourceReader(resourceReader),
772         fileLoc(fileLoc) {}
773 
774   /// Initialize the attribute and type information within the reader.
775   LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
776                            ArrayRef<uint8_t> sectionData,
777                            ArrayRef<uint8_t> offsetSectionData);
778 
779   /// Resolve the attribute or type at the given index. Returns nullptr on
780   /// failure.
781   Attribute resolveAttribute(size_t index) {
782     return resolveEntry(attributes, index, "Attribute");
783   }
784   Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
785 
786   /// Parse a reference to an attribute or type using the given reader.
787   LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
788     uint64_t attrIdx;
789     if (failed(reader.parseVarInt(attrIdx)))
790       return failure();
791     result = resolveAttribute(attrIdx);
792     return success(!!result);
793   }
794   LogicalResult parseType(EncodingReader &reader, Type &result) {
795     uint64_t typeIdx;
796     if (failed(reader.parseVarInt(typeIdx)))
797       return failure();
798     result = resolveType(typeIdx);
799     return success(!!result);
800   }
801 
802   template <typename T>
803   LogicalResult parseAttribute(EncodingReader &reader, T &result) {
804     Attribute baseResult;
805     if (failed(parseAttribute(reader, baseResult)))
806       return failure();
807     if ((result = dyn_cast<T>(baseResult)))
808       return success();
809     return reader.emitError("expected attribute of type: ",
810                             llvm::getTypeName<T>(), ", but got: ", baseResult);
811   }
812 
813 private:
814   /// Resolve the given entry at `index`.
815   template <typename T>
816   T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
817                  StringRef entryType);
818 
819   /// Parse an entry using the given reader that was encoded using the textual
820   /// assembly format.
821   template <typename T>
822   LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
823                               StringRef entryType);
824 
825   /// Parse an entry using the given reader that was encoded using a custom
826   /// bytecode format.
827   template <typename T>
828   LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
829                                  StringRef entryType);
830 
831   /// The string section reader used to resolve string references when parsing
832   /// custom encoded attribute/type entries.
833   StringSectionReader &stringReader;
834 
835   /// The resource section reader used to resolve resource references when
836   /// parsing custom encoded attribute/type entries.
837   ResourceSectionReader &resourceReader;
838 
839   /// The set of attribute and type entries.
840   SmallVector<AttrEntry> attributes;
841   SmallVector<TypeEntry> types;
842 
843   /// A location used for error emission.
844   Location fileLoc;
845 };
846 
847 class DialectReader : public DialectBytecodeReader {
848 public:
849   DialectReader(AttrTypeReader &attrTypeReader,
850                 StringSectionReader &stringReader,
851                 ResourceSectionReader &resourceReader, EncodingReader &reader)
852       : attrTypeReader(attrTypeReader), stringReader(stringReader),
853         resourceReader(resourceReader), reader(reader) {}
854 
855   InFlightDiagnostic emitError(const Twine &msg) override {
856     return reader.emitError(msg);
857   }
858 
859   DialectReader withEncodingReader(EncodingReader &encReader) {
860     return DialectReader(attrTypeReader, stringReader, resourceReader,
861                          encReader);
862   }
863 
864   Location getLoc() const { return reader.getLoc(); }
865 
866   //===--------------------------------------------------------------------===//
867   // IR
868   //===--------------------------------------------------------------------===//
869 
870   LogicalResult readAttribute(Attribute &result) override {
871     return attrTypeReader.parseAttribute(reader, result);
872   }
873 
874   LogicalResult readType(Type &result) override {
875     return attrTypeReader.parseType(reader, result);
876   }
877 
878   FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
879     AsmDialectResourceHandle handle;
880     if (failed(resourceReader.parseResourceHandle(reader, handle)))
881       return failure();
882     return handle;
883   }
884 
885   //===--------------------------------------------------------------------===//
886   // Primitives
887   //===--------------------------------------------------------------------===//
888 
889   LogicalResult readVarInt(uint64_t &result) override {
890     return reader.parseVarInt(result);
891   }
892 
893   LogicalResult readSignedVarInt(int64_t &result) override {
894     uint64_t unsignedResult;
895     if (failed(reader.parseSignedVarInt(unsignedResult)))
896       return failure();
897     result = static_cast<int64_t>(unsignedResult);
898     return success();
899   }
900 
901   FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
902     // Small values are encoded using a single byte.
903     if (bitWidth <= 8) {
904       uint8_t value;
905       if (failed(reader.parseByte(value)))
906         return failure();
907       return APInt(bitWidth, value);
908     }
909 
910     // Large values up to 64 bits are encoded using a single varint.
911     if (bitWidth <= 64) {
912       uint64_t value;
913       if (failed(reader.parseSignedVarInt(value)))
914         return failure();
915       return APInt(bitWidth, value);
916     }
917 
918     // Otherwise, for really big values we encode the array of active words in
919     // the value.
920     uint64_t numActiveWords;
921     if (failed(reader.parseVarInt(numActiveWords)))
922       return failure();
923     SmallVector<uint64_t, 4> words(numActiveWords);
924     for (uint64_t i = 0; i < numActiveWords; ++i)
925       if (failed(reader.parseSignedVarInt(words[i])))
926         return failure();
927     return APInt(bitWidth, words);
928   }
929 
930   FailureOr<APFloat>
931   readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
932     FailureOr<APInt> intVal =
933         readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics));
934     if (failed(intVal))
935       return failure();
936     return APFloat(semantics, *intVal);
937   }
938 
939   LogicalResult readString(StringRef &result) override {
940     return stringReader.parseString(reader, result);
941   }
942 
943   LogicalResult readBlob(ArrayRef<char> &result) override {
944     uint64_t dataSize;
945     ArrayRef<uint8_t> data;
946     if (failed(reader.parseVarInt(dataSize)) ||
947         failed(reader.parseBytes(dataSize, data)))
948       return failure();
949     result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()),
950                             data.size());
951     return success();
952   }
953 
954 private:
955   AttrTypeReader &attrTypeReader;
956   StringSectionReader &stringReader;
957   ResourceSectionReader &resourceReader;
958   EncodingReader &reader;
959 };
960 } // namespace
961 
962 LogicalResult
963 AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
964                            ArrayRef<uint8_t> sectionData,
965                            ArrayRef<uint8_t> offsetSectionData) {
966   EncodingReader offsetReader(offsetSectionData, fileLoc);
967 
968   // Parse the number of attribute and type entries.
969   uint64_t numAttributes, numTypes;
970   if (failed(offsetReader.parseVarInt(numAttributes)) ||
971       failed(offsetReader.parseVarInt(numTypes)))
972     return failure();
973   attributes.resize(numAttributes);
974   types.resize(numTypes);
975 
976   // A functor used to accumulate the offsets for the entries in the given
977   // range.
978   uint64_t currentOffset = 0;
979   auto parseEntries = [&](auto &&range) {
980     size_t currentIndex = 0, endIndex = range.size();
981 
982     // Parse an individual entry.
983     auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
984       auto &entry = range[currentIndex++];
985 
986       uint64_t entrySize;
987       if (failed(offsetReader.parseVarIntWithFlag(entrySize,
988                                                   entry.hasCustomEncoding)))
989         return failure();
990 
991       // Verify that the offset is actually valid.
992       if (currentOffset + entrySize > sectionData.size()) {
993         return offsetReader.emitError(
994             "Attribute or Type entry offset points past the end of section");
995       }
996 
997       entry.data = sectionData.slice(currentOffset, entrySize);
998       entry.dialect = dialect;
999       currentOffset += entrySize;
1000       return success();
1001     };
1002     while (currentIndex != endIndex)
1003       if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
1004         return failure();
1005     return success();
1006   };
1007 
1008   // Process each of the attributes, and then the types.
1009   if (failed(parseEntries(attributes)) || failed(parseEntries(types)))
1010     return failure();
1011 
1012   // Ensure that we read everything from the section.
1013   if (!offsetReader.empty()) {
1014     return offsetReader.emitError(
1015         "unexpected trailing data in the Attribute/Type offset section");
1016   }
1017   return success();
1018 }
1019 
1020 template <typename T>
1021 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
1022                                StringRef entryType) {
1023   if (index >= entries.size()) {
1024     emitError(fileLoc) << "invalid " << entryType << " index: " << index;
1025     return {};
1026   }
1027 
1028   // If the entry has already been resolved, there is nothing left to do.
1029   Entry<T> &entry = entries[index];
1030   if (entry.entry)
1031     return entry.entry;
1032 
1033   // Parse the entry.
1034   EncodingReader reader(entry.data, fileLoc);
1035 
1036   // Parse based on how the entry was encoded.
1037   if (entry.hasCustomEncoding) {
1038     if (failed(parseCustomEntry(entry, reader, entryType)))
1039       return T();
1040   } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
1041     return T();
1042   }
1043 
1044   if (!reader.empty()) {
1045     reader.emitError("unexpected trailing bytes after " + entryType + " entry");
1046     return T();
1047   }
1048   return entry.entry;
1049 }
1050 
1051 template <typename T>
1052 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1053                                             StringRef entryType) {
1054   StringRef asmStr;
1055   if (failed(reader.parseNullTerminatedString(asmStr)))
1056     return failure();
1057 
1058   // Invoke the MLIR assembly parser to parse the entry text.
1059   size_t numRead = 0;
1060   MLIRContext *context = fileLoc->getContext();
1061   if constexpr (std::is_same_v<T, Type>)
1062     result =
1063         ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
1064   else
1065     result = ::parseAttribute(asmStr, context, Type(), &numRead,
1066                               /*isKnownNullTerminated=*/true);
1067   if (!result)
1068     return failure();
1069 
1070   // Ensure there weren't dangling characters after the entry.
1071   if (numRead != asmStr.size()) {
1072     return reader.emitError("trailing characters found after ", entryType,
1073                             " assembly format: ", asmStr.drop_front(numRead));
1074   }
1075   return success();
1076 }
1077 
1078 template <typename T>
1079 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1080                                                EncodingReader &reader,
1081                                                StringRef entryType) {
1082   DialectReader dialectReader(*this, stringReader, resourceReader, reader);
1083   if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
1084     return failure();
1085   // Ensure that the dialect implements the bytecode interface.
1086   if (!entry.dialect->interface) {
1087     return reader.emitError("dialect '", entry.dialect->name,
1088                             "' does not implement the bytecode interface");
1089   }
1090 
1091   // Ask the dialect to parse the entry. If the dialect is versioned, parse
1092   // using the versioned encoding readers.
1093   if (entry.dialect->loadedVersion.get()) {
1094     if constexpr (std::is_same_v<T, Type>)
1095       entry.entry = entry.dialect->interface->readType(
1096           dialectReader, *entry.dialect->loadedVersion);
1097     else
1098       entry.entry = entry.dialect->interface->readAttribute(
1099           dialectReader, *entry.dialect->loadedVersion);
1100 
1101   } else {
1102     if constexpr (std::is_same_v<T, Type>)
1103       entry.entry = entry.dialect->interface->readType(dialectReader);
1104     else
1105       entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1106   }
1107   return success(!!entry.entry);
1108 }
1109 
1110 //===----------------------------------------------------------------------===//
1111 // Bytecode Reader
1112 //===----------------------------------------------------------------------===//
1113 
1114 /// This class is used to read a bytecode buffer and translate it into MLIR.
1115 class mlir::BytecodeReader::Impl {
1116   struct RegionReadState;
1117   using LazyLoadableOpsInfo =
1118       std::list<std::pair<Operation *, RegionReadState>>;
1119   using LazyLoadableOpsMap =
1120       DenseMap<Operation *, LazyLoadableOpsInfo::iterator>;
1121 
1122 public:
1123   Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
1124        llvm::MemoryBufferRef buffer,
1125        const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1126       : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1127         attrTypeReader(stringReader, resourceReader, fileLoc),
1128         // Use the builtin unrealized conversion cast operation to represent
1129         // forward references to values that aren't yet defined.
1130         forwardRefOpState(UnknownLoc::get(config.getContext()),
1131                           "builtin.unrealized_conversion_cast", ValueRange(),
1132                           NoneType::get(config.getContext())),
1133         buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1134 
1135   /// Read the bytecode defined within `buffer` into the given block.
1136   LogicalResult read(Block *block,
1137                      llvm::function_ref<bool(Operation *)> lazyOps);
1138 
1139   /// Return the number of ops that haven't been materialized yet.
1140   int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
1141 
1142   bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); }
1143 
1144   /// Materialize the provided operation, invoke the lazyOpsCallback on every
1145   /// newly found lazy operation.
1146   LogicalResult
1147   materialize(Operation *op,
1148               llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1149     this->lazyOpsCallback = lazyOpsCallback;
1150     auto resetlazyOpsCallback =
1151         llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1152     auto it = lazyLoadableOpsMap.find(op);
1153     assert(it != lazyLoadableOpsMap.end() &&
1154            "materialize called on non-materializable op");
1155     return materialize(it);
1156   }
1157 
1158   /// Materialize all operations.
1159   LogicalResult materializeAll() {
1160     while (!lazyLoadableOpsMap.empty()) {
1161       if (failed(materialize(lazyLoadableOpsMap.begin())))
1162         return failure();
1163     }
1164     return success();
1165   }
1166 
1167   /// Finalize the lazy-loading by calling back with every op that hasn't been
1168   /// materialized to let the client decide if the op should be deleted or
1169   /// materialized. The op is materialized if the callback returns true, deleted
1170   /// otherwise.
1171   LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
1172     while (!lazyLoadableOps.empty()) {
1173       Operation *op = lazyLoadableOps.begin()->first;
1174       if (shouldMaterialize(op)) {
1175         if (failed(materialize(lazyLoadableOpsMap.find(op))))
1176           return failure();
1177         continue;
1178       }
1179       op->dropAllReferences();
1180       op->erase();
1181       lazyLoadableOps.pop_front();
1182       lazyLoadableOpsMap.erase(op);
1183     }
1184     return success();
1185   }
1186 
1187 private:
1188   LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
1189     assert(it != lazyLoadableOpsMap.end() &&
1190            "materialize called on non-materializable op");
1191     valueScopes.emplace_back();
1192     std::vector<RegionReadState> regionStack;
1193     regionStack.push_back(std::move(it->getSecond()->second));
1194     lazyLoadableOps.erase(it->getSecond());
1195     lazyLoadableOpsMap.erase(it);
1196     auto result = parseRegions(regionStack, regionStack.back());
1197     assert(regionStack.empty());
1198     return result;
1199   }
1200 
1201   /// Return the context for this config.
1202   MLIRContext *getContext() const { return config.getContext(); }
1203 
1204   /// Parse the bytecode version.
1205   LogicalResult parseVersion(EncodingReader &reader);
1206 
1207   //===--------------------------------------------------------------------===//
1208   // Dialect Section
1209 
1210   LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1211 
1212   /// Parse an operation name reference using the given reader.
1213   FailureOr<OperationName> parseOpName(EncodingReader &reader);
1214 
1215   //===--------------------------------------------------------------------===//
1216   // Attribute/Type Section
1217 
1218   /// Parse an attribute or type using the given reader.
1219   template <typename T>
1220   LogicalResult parseAttribute(EncodingReader &reader, T &result) {
1221     return attrTypeReader.parseAttribute(reader, result);
1222   }
1223   LogicalResult parseType(EncodingReader &reader, Type &result) {
1224     return attrTypeReader.parseType(reader, result);
1225   }
1226 
1227   //===--------------------------------------------------------------------===//
1228   // Resource Section
1229 
1230   LogicalResult
1231   parseResourceSection(EncodingReader &reader,
1232                        std::optional<ArrayRef<uint8_t>> resourceData,
1233                        std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1234 
1235   //===--------------------------------------------------------------------===//
1236   // IR Section
1237 
1238   /// This struct represents the current read state of a range of regions. This
1239   /// struct is used to enable iterative parsing of regions.
1240   struct RegionReadState {
1241     RegionReadState(Operation *op, EncodingReader *reader,
1242                     bool isIsolatedFromAbove)
1243         : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
1244     RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1245                     bool isIsolatedFromAbove)
1246         : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1247           isIsolatedFromAbove(isIsolatedFromAbove) {}
1248 
1249     /// The current regions being read.
1250     MutableArrayRef<Region>::iterator curRegion, endRegion;
1251     /// This is the reader to use for this region, this pointer is pointing to
1252     /// the parent region reader unless the current region is IsolatedFromAbove,
1253     /// in which case the pointer is pointing to the `owningReader` which is a
1254     /// section dedicated to the current region.
1255     EncodingReader *reader;
1256     std::unique_ptr<EncodingReader> owningReader;
1257 
1258     /// The number of values defined immediately within this region.
1259     unsigned numValues = 0;
1260 
1261     /// The current blocks of the region being read.
1262     SmallVector<Block *> curBlocks;
1263     Region::iterator curBlock = {};
1264 
1265     /// The number of operations remaining to be read from the current block
1266     /// being read.
1267     uint64_t numOpsRemaining = 0;
1268 
1269     /// A flag indicating if the regions being read are isolated from above.
1270     bool isIsolatedFromAbove = false;
1271   };
1272 
1273   LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
1274   LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
1275                              RegionReadState &readState);
1276   FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1277                                                RegionReadState &readState,
1278                                                bool &isIsolatedFromAbove);
1279 
1280   LogicalResult parseRegion(RegionReadState &readState);
1281   LogicalResult parseBlockHeader(EncodingReader &reader,
1282                                  RegionReadState &readState);
1283   LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
1284 
1285   //===--------------------------------------------------------------------===//
1286   // Value Processing
1287 
1288   /// Parse an operand reference using the given reader. Returns nullptr in the
1289   /// case of failure.
1290   Value parseOperand(EncodingReader &reader);
1291 
1292   /// Sequentially define the given value range.
1293   LogicalResult defineValues(EncodingReader &reader, ValueRange values);
1294 
1295   /// Create a value to use for a forward reference.
1296   Value createForwardRef();
1297 
1298   //===--------------------------------------------------------------------===//
1299   // Use-list order helpers
1300 
1301   /// This struct is a simple storage that contains information required to
1302   /// reorder the use-list of a value with respect to the pre-order traversal
1303   /// ordering.
1304   struct UseListOrderStorage {
1305     UseListOrderStorage(bool isIndexPairEncoding,
1306                         SmallVector<unsigned, 4> &&indices)
1307         : indices(std::move(indices)),
1308           isIndexPairEncoding(isIndexPairEncoding){};
1309     /// The vector containing the information required to reorder the
1310     /// use-list of a value.
1311     SmallVector<unsigned, 4> indices;
1312 
1313     /// Whether indices represent a pair of type `(src, dst)` or it is a direct
1314     /// indexing, such as `dst = order[src]`.
1315     bool isIndexPairEncoding;
1316   };
1317 
1318   /// Parse use-list order from bytecode for a range of values if available. The
1319   /// range is expected to be either a block argument or an op result range. On
1320   /// success, return a map of the position in the range and the use-list order
1321   /// encoding. The function assumes to know the size of the range it is
1322   /// processing.
1323   using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
1324   FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1325                                                    uint64_t rangeSize);
1326 
1327   /// Shuffle the use-chain according to the order parsed.
1328   LogicalResult sortUseListOrder(Value value);
1329 
1330   /// Recursively visit all the values defined within topLevelOp and sort the
1331   /// use-list orders according to the indices parsed.
1332   LogicalResult processUseLists(Operation *topLevelOp);
1333 
1334   //===--------------------------------------------------------------------===//
1335   // Fields
1336 
1337   /// This class represents a single value scope, in which a value scope is
1338   /// delimited by isolated from above regions.
1339   struct ValueScope {
1340     /// Push a new region state onto this scope, reserving enough values for
1341     /// those defined within the current region of the provided state.
1342     void push(RegionReadState &readState) {
1343       nextValueIDs.push_back(values.size());
1344       values.resize(values.size() + readState.numValues);
1345     }
1346 
1347     /// Pop the values defined for the current region within the provided region
1348     /// state.
1349     void pop(RegionReadState &readState) {
1350       values.resize(values.size() - readState.numValues);
1351       nextValueIDs.pop_back();
1352     }
1353 
1354     /// The set of values defined in this scope.
1355     std::vector<Value> values;
1356 
1357     /// The ID for the next defined value for each region current being
1358     /// processed in this scope.
1359     SmallVector<unsigned, 4> nextValueIDs;
1360   };
1361 
1362   /// The configuration of the parser.
1363   const ParserConfig &config;
1364 
1365   /// A location to use when emitting errors.
1366   Location fileLoc;
1367 
1368   /// Flag that indicates if lazyloading is enabled.
1369   bool lazyLoading;
1370 
1371   /// Keep track of operations that have been lazy loaded (their regions haven't
1372   /// been materialized), along with the `RegionReadState` that allows to
1373   /// lazy-load the regions nested under the operation.
1374   LazyLoadableOpsInfo lazyLoadableOps;
1375   LazyLoadableOpsMap lazyLoadableOpsMap;
1376   llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1377 
1378   /// The reader used to process attribute and types within the bytecode.
1379   AttrTypeReader attrTypeReader;
1380 
1381   /// The version of the bytecode being read.
1382   uint64_t version = 0;
1383 
1384   /// The producer of the bytecode being read.
1385   StringRef producer;
1386 
1387   /// The table of IR units referenced within the bytecode file.
1388   SmallVector<BytecodeDialect> dialects;
1389   SmallVector<BytecodeOperationName> opNames;
1390 
1391   /// The reader used to process resources within the bytecode.
1392   ResourceSectionReader resourceReader;
1393 
1394   /// Worklist of values with custom use-list orders to process before the end
1395   /// of the parsing.
1396   DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1397 
1398   /// The table of strings referenced within the bytecode file.
1399   StringSectionReader stringReader;
1400 
1401   /// The current set of available IR value scopes.
1402   std::vector<ValueScope> valueScopes;
1403 
1404   /// The global pre-order operation ordering.
1405   DenseMap<Operation *, unsigned> operationIDs;
1406 
1407   /// A block containing the set of operations defined to create forward
1408   /// references.
1409   Block forwardRefOps;
1410 
1411   /// A block containing previously created, and no longer used, forward
1412   /// reference operations.
1413   Block openForwardRefOps;
1414 
1415   /// An operation state used when instantiating forward references.
1416   OperationState forwardRefOpState;
1417 
1418   /// Reference to the input buffer.
1419   llvm::MemoryBufferRef buffer;
1420 
1421   /// The optional owning source manager, which when present may be used to
1422   /// extend the lifetime of the input buffer.
1423   const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1424 };
1425 
1426 LogicalResult BytecodeReader::Impl::read(
1427     Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1428   EncodingReader reader(buffer.getBuffer(), fileLoc);
1429   this->lazyOpsCallback = lazyOpsCallback;
1430   auto resetlazyOpsCallback =
1431       llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1432 
1433   // Skip over the bytecode header, this should have already been checked.
1434   if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
1435     return failure();
1436   // Parse the bytecode version and producer.
1437   if (failed(parseVersion(reader)) ||
1438       failed(reader.parseNullTerminatedString(producer)))
1439     return failure();
1440 
1441   // Add a diagnostic handler that attaches a note that includes the original
1442   // producer of the bytecode.
1443   ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
1444     diag.attachNote() << "in bytecode version " << version
1445                       << " produced by: " << producer;
1446     return failure();
1447   });
1448 
1449   // Parse the raw data for each of the top-level sections of the bytecode.
1450   std::optional<ArrayRef<uint8_t>>
1451       sectionDatas[bytecode::Section::kNumSections];
1452   while (!reader.empty()) {
1453     // Read the next section from the bytecode.
1454     bytecode::Section::ID sectionID;
1455     ArrayRef<uint8_t> sectionData;
1456     if (failed(reader.parseSection(sectionID, sectionData)))
1457       return failure();
1458 
1459     // Check for duplicate sections, we only expect one instance of each.
1460     if (sectionDatas[sectionID]) {
1461       return reader.emitError("duplicate top-level section: ",
1462                               ::toString(sectionID));
1463     }
1464     sectionDatas[sectionID] = sectionData;
1465   }
1466   // Check that all of the required sections were found.
1467   for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
1468     bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
1469     if (!sectionDatas[i] && !isSectionOptional(sectionID)) {
1470       return reader.emitError("missing data for top-level section: ",
1471                               ::toString(sectionID));
1472     }
1473   }
1474 
1475   // Process the string section first.
1476   if (failed(stringReader.initialize(
1477           fileLoc, *sectionDatas[bytecode::Section::kString])))
1478     return failure();
1479 
1480   // Process the dialect section.
1481   if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
1482     return failure();
1483 
1484   // Process the resource section if present.
1485   if (failed(parseResourceSection(
1486           reader, sectionDatas[bytecode::Section::kResource],
1487           sectionDatas[bytecode::Section::kResourceOffset])))
1488     return failure();
1489 
1490   // Process the attribute and type section.
1491   if (failed(attrTypeReader.initialize(
1492           dialects, *sectionDatas[bytecode::Section::kAttrType],
1493           *sectionDatas[bytecode::Section::kAttrTypeOffset])))
1494     return failure();
1495 
1496   // Finally, process the IR section.
1497   return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
1498 }
1499 
1500 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
1501   if (failed(reader.parseVarInt(version)))
1502     return failure();
1503 
1504   // Validate the bytecode version.
1505   uint64_t currentVersion = bytecode::kVersion;
1506   uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
1507   if (version < minSupportedVersion) {
1508     return reader.emitError("bytecode version ", version,
1509                             " is older than the current version of ",
1510                             currentVersion, ", and upgrade is not supported");
1511   }
1512   if (version > currentVersion) {
1513     return reader.emitError("bytecode version ", version,
1514                             " is newer than the current version ",
1515                             currentVersion);
1516   }
1517   // Override any request to lazy-load if the bytecode version is too old.
1518   if (version < 2)
1519     lazyLoading = false;
1520   return success();
1521 }
1522 
1523 //===----------------------------------------------------------------------===//
1524 // Dialect Section
1525 
1526 LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
1527   if (dialect)
1528     return success();
1529   Dialect *loadedDialect = ctx->getOrLoadDialect(name);
1530   if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
1531     return reader.emitError("dialect '")
1532            << name
1533            << "' is unknown. If this is intended, please call "
1534               "allowUnregisteredDialects() on the MLIRContext, or use "
1535               "-allow-unregistered-dialect with the MLIR tool used.";
1536   }
1537   dialect = loadedDialect;
1538 
1539   // If the dialect was actually loaded, check to see if it has a bytecode
1540   // interface.
1541   if (loadedDialect)
1542     interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
1543   if (!versionBuffer.empty()) {
1544     if (!interface)
1545       return reader.emitError("dialect '")
1546              << name
1547              << "' does not implement the bytecode interface, "
1548                 "but found a version entry";
1549     EncodingReader encReader(versionBuffer, reader.getLoc());
1550     DialectReader versionReader = reader.withEncodingReader(encReader);
1551     loadedVersion = interface->readVersion(versionReader);
1552     if (!loadedVersion)
1553       return failure();
1554   }
1555   return success();
1556 }
1557 
1558 LogicalResult
1559 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
1560   EncodingReader sectionReader(sectionData, fileLoc);
1561 
1562   // Parse the number of dialects in the section.
1563   uint64_t numDialects;
1564   if (failed(sectionReader.parseVarInt(numDialects)))
1565     return failure();
1566   dialects.resize(numDialects);
1567 
1568   // Parse each of the dialects.
1569   for (uint64_t i = 0; i < numDialects; ++i) {
1570     /// Before version 1, there wasn't any versioning available for dialects,
1571     /// and the entryIdx represent the string itself.
1572     if (version == 0) {
1573       if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
1574         return failure();
1575       continue;
1576     }
1577     // Parse ID representing dialect and version.
1578     uint64_t dialectNameIdx;
1579     bool versionAvailable;
1580     if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
1581                                                  versionAvailable)))
1582       return failure();
1583     if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
1584                                                dialects[i].name)))
1585       return failure();
1586     if (versionAvailable) {
1587       bytecode::Section::ID sectionID;
1588       if (failed(
1589               sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
1590         return failure();
1591       if (sectionID != bytecode::Section::kDialectVersions) {
1592         emitError(fileLoc, "expected dialect version section");
1593         return failure();
1594       }
1595     }
1596   }
1597 
1598   // Parse the operation names, which are grouped by dialect.
1599   auto parseOpName = [&](BytecodeDialect *dialect) {
1600     StringRef opName;
1601     if (failed(stringReader.parseString(sectionReader, opName)))
1602       return failure();
1603     opNames.emplace_back(dialect, opName);
1604     return success();
1605   };
1606   // Avoid re-allocation in bytecode version > 3 where the number of ops are
1607   // known.
1608   if (version > 3) {
1609     uint64_t numOps;
1610     if (failed(sectionReader.parseVarInt(numOps)))
1611       return failure();
1612     opNames.reserve(numOps);
1613   }
1614   while (!sectionReader.empty())
1615     if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName)))
1616       return failure();
1617   return success();
1618 }
1619 
1620 FailureOr<OperationName>
1621 BytecodeReader::Impl::parseOpName(EncodingReader &reader) {
1622   BytecodeOperationName *opName = nullptr;
1623   if (failed(parseEntry(reader, opNames, opName, "operation name")))
1624     return failure();
1625 
1626   // Check to see if this operation name has already been resolved. If we
1627   // haven't, load the dialect and build the operation name.
1628   if (!opName->opName) {
1629     // Load the dialect and its version.
1630     DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1631                                 reader);
1632     if (failed(opName->dialect->load(dialectReader, getContext())))
1633       return failure();
1634     // If the opName is empty, this is because we use to accept names such as
1635     // `foo` without any `.` separator. We shouldn't tolerate this in textual
1636     // format anymore but for now we'll be backward compatible. This can only
1637     // happen with unregistered dialects.
1638     if (opName->name.empty()) {
1639       if (opName->dialect->getLoadedDialect())
1640         return emitError(fileLoc) << "has an empty opname for dialect '"
1641                                   << opName->dialect->name << "'\n";
1642 
1643       opName->opName.emplace(opName->dialect->name, getContext());
1644     } else {
1645       opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
1646                              getContext());
1647     }
1648   }
1649   return *opName->opName;
1650 }
1651 
1652 //===----------------------------------------------------------------------===//
1653 // Resource Section
1654 
1655 LogicalResult BytecodeReader::Impl::parseResourceSection(
1656     EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
1657     std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
1658   // Ensure both sections are either present or not.
1659   if (resourceData.has_value() != resourceOffsetData.has_value()) {
1660     if (resourceOffsetData)
1661       return emitError(fileLoc, "unexpected resource offset section when "
1662                                 "resource section is not present");
1663     return emitError(
1664         fileLoc,
1665         "expected resource offset section when resource section is present");
1666   }
1667 
1668   // If the resource sections are absent, there is nothing to do.
1669   if (!resourceData)
1670     return success();
1671 
1672   // Initialize the resource reader with the resource sections.
1673   DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1674                               reader);
1675   return resourceReader.initialize(fileLoc, config, dialects, stringReader,
1676                                    *resourceData, *resourceOffsetData,
1677                                    dialectReader, bufferOwnerRef);
1678 }
1679 
1680 //===----------------------------------------------------------------------===//
1681 // UseListOrder Helpers
1682 
1683 FailureOr<BytecodeReader::Impl::UseListMapT>
1684 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
1685                                                 uint64_t numResults) {
1686   BytecodeReader::Impl::UseListMapT map;
1687   uint64_t numValuesToRead = 1;
1688   if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead)))
1689     return failure();
1690 
1691   for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
1692     uint64_t resultIdx = 0;
1693     if (numResults > 1 && failed(reader.parseVarInt(resultIdx)))
1694       return failure();
1695 
1696     uint64_t numValues;
1697     bool indexPairEncoding;
1698     if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
1699       return failure();
1700 
1701     SmallVector<unsigned, 4> useListOrders;
1702     for (size_t idx = 0; idx < numValues; idx++) {
1703       uint64_t index;
1704       if (failed(reader.parseVarInt(index)))
1705         return failure();
1706       useListOrders.push_back(index);
1707     }
1708 
1709     // Store in a map the result index
1710     map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
1711                                                    std::move(useListOrders)));
1712   }
1713 
1714   return map;
1715 }
1716 
1717 /// Sorts each use according to the order specified in the use-list parsed. If
1718 /// the custom use-list is not found, this means that the order needs to be
1719 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie
1720 /// on the same operation, the order will follow the reverse operand number
1721 /// ordering.
1722 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
1723   // Early return for trivial use-lists.
1724   if (value.use_empty() || value.hasOneUse())
1725     return success();
1726 
1727   bool hasIncomingOrder =
1728       valueToUseListMap.contains(value.getAsOpaquePointer());
1729 
1730   // Compute the current order of the use-list with respect to the global
1731   // ordering. Detect if the order is already sorted while doing so.
1732   bool alreadySorted = true;
1733   auto &firstUse = *value.use_begin();
1734   uint64_t prevID =
1735       bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner()));
1736   llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
1737   for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
1738     uint64_t currentID = bytecode::getUseID(
1739         item.value(), operationIDs.at(item.value().getOwner()));
1740     alreadySorted &= prevID > currentID;
1741     currentOrder.push_back({item.index(), currentID});
1742     prevID = currentID;
1743   }
1744 
1745   // If the order is already sorted, and there wasn't a custom order to apply
1746   // from the bytecode file, we are done.
1747   if (alreadySorted && !hasIncomingOrder)
1748     return success();
1749 
1750   // If not already sorted, sort the indices of the current order by descending
1751   // useIDs.
1752   if (!alreadySorted)
1753     std::sort(
1754         currentOrder.begin(), currentOrder.end(),
1755         [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
1756 
1757   if (!hasIncomingOrder) {
1758     // If the bytecode file did not contain any custom use-list order, it means
1759     // that the order was descending useID. Hence, shuffle by the first index
1760     // of the `currentOrder` pair.
1761     SmallVector<unsigned> shuffle = SmallVector<unsigned>(
1762         llvm::map_range(currentOrder, [&](auto item) { return item.first; }));
1763     value.shuffleUseList(shuffle);
1764     return success();
1765   }
1766 
1767   // Pull the custom order info from the map.
1768   UseListOrderStorage customOrder =
1769       valueToUseListMap.at(value.getAsOpaquePointer());
1770   SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
1771   uint64_t numUses =
1772       std::distance(value.getUses().begin(), value.getUses().end());
1773 
1774   // If the encoding was a pair of indices `(src, dst)` for every permutation,
1775   // reconstruct the shuffle vector for every use. Initialize the shuffle vector
1776   // as identity, and then apply the mapping encoded in the indices.
1777   if (customOrder.isIndexPairEncoding) {
1778     // Return failure if the number of indices was not representing pairs.
1779     if (shuffle.size() & 1)
1780       return failure();
1781 
1782     SmallVector<unsigned, 4> newShuffle(numUses);
1783     size_t idx = 0;
1784     std::iota(newShuffle.begin(), newShuffle.end(), idx);
1785     for (idx = 0; idx < shuffle.size(); idx += 2)
1786       newShuffle[shuffle[idx]] = shuffle[idx + 1];
1787 
1788     shuffle = std::move(newShuffle);
1789   }
1790 
1791   // Make sure that the indices represent a valid mapping. That is, the sum of
1792   // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
1793   // duplicates are allowed in the list.
1794   DenseSet<unsigned> set;
1795   uint64_t accumulator = 0;
1796   for (const auto &elem : shuffle) {
1797     if (set.contains(elem))
1798       return failure();
1799     accumulator += elem;
1800     set.insert(elem);
1801   }
1802   if (numUses != shuffle.size() ||
1803       accumulator != (((numUses - 1) * numUses) >> 1))
1804     return failure();
1805 
1806   // Apply the current ordering map onto the shuffle vector to get the final
1807   // use-list sorting indices before shuffling.
1808   shuffle = SmallVector<unsigned, 4>(llvm::map_range(
1809       currentOrder, [&](auto item) { return shuffle[item.first]; }));
1810   value.shuffleUseList(shuffle);
1811   return success();
1812 }
1813 
1814 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
1815   // Precompute operation IDs according to the pre-order walk of the IR. We
1816   // can't do this while parsing since parseRegions ordering is not strictly
1817   // equal to the pre-order walk.
1818   unsigned operationID = 0;
1819   topLevelOp->walk<mlir::WalkOrder::PreOrder>(
1820       [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
1821 
1822   auto blockWalk = topLevelOp->walk([this](Block *block) {
1823     for (auto arg : block->getArguments())
1824       if (failed(sortUseListOrder(arg)))
1825         return WalkResult::interrupt();
1826     return WalkResult::advance();
1827   });
1828 
1829   auto resultWalk = topLevelOp->walk([this](Operation *op) {
1830     for (auto result : op->getResults())
1831       if (failed(sortUseListOrder(result)))
1832         return WalkResult::interrupt();
1833     return WalkResult::advance();
1834   });
1835 
1836   return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
1837 }
1838 
1839 //===----------------------------------------------------------------------===//
1840 // IR Section
1841 
1842 LogicalResult
1843 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
1844                                      Block *block) {
1845   EncodingReader reader(sectionData, fileLoc);
1846 
1847   // A stack of operation regions currently being read from the bytecode.
1848   std::vector<RegionReadState> regionStack;
1849 
1850   // Parse the top-level block using a temporary module operation.
1851   OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
1852   regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
1853   regionStack.back().curBlocks.push_back(moduleOp->getBody());
1854   regionStack.back().curBlock = regionStack.back().curRegion->begin();
1855   if (failed(parseBlockHeader(reader, regionStack.back())))
1856     return failure();
1857   valueScopes.emplace_back();
1858   valueScopes.back().push(regionStack.back());
1859 
1860   // Iteratively parse regions until everything has been resolved.
1861   while (!regionStack.empty())
1862     if (failed(parseRegions(regionStack, regionStack.back())))
1863       return failure();
1864   if (!forwardRefOps.empty()) {
1865     return reader.emitError(
1866         "not all forward unresolved forward operand references");
1867   }
1868 
1869   // Sort use-lists according to what specified in bytecode.
1870   if (failed(processUseLists(*moduleOp)))
1871     return reader.emitError(
1872         "parsed use-list orders were invalid and could not be applied");
1873 
1874   // Resolve dialect version.
1875   for (const BytecodeDialect &byteCodeDialect : dialects) {
1876     // Parsing is complete, give an opportunity to each dialect to visit the
1877     // IR and perform upgrades.
1878     if (!byteCodeDialect.loadedVersion)
1879       continue;
1880     if (byteCodeDialect.interface &&
1881         failed(byteCodeDialect.interface->upgradeFromVersion(
1882             *moduleOp, *byteCodeDialect.loadedVersion)))
1883       return failure();
1884   }
1885 
1886   // Verify that the parsed operations are valid.
1887   if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
1888     return failure();
1889 
1890   // Splice the parsed operations over to the provided top-level block.
1891   auto &parsedOps = moduleOp->getBody()->getOperations();
1892   auto &destOps = block->getOperations();
1893   destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
1894   return success();
1895 }
1896 
1897 LogicalResult
1898 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
1899                                    RegionReadState &readState) {
1900   // Process regions, blocks, and operations until the end or if a nested
1901   // region is encountered. In this case we push a new state in regionStack and
1902   // return, the processing of the current region will resume afterward.
1903   for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
1904     // If the current block hasn't been setup yet, parse the header for this
1905     // region. The current block is already setup when this function was
1906     // interrupted to recurse down in a nested region and we resume the current
1907     // block after processing the nested region.
1908     if (readState.curBlock == Region::iterator()) {
1909       if (failed(parseRegion(readState)))
1910         return failure();
1911 
1912       // If the region is empty, there is nothing to more to do.
1913       if (readState.curRegion->empty())
1914         continue;
1915     }
1916 
1917     // Parse the blocks within the region.
1918     EncodingReader &reader = *readState.reader;
1919     do {
1920       while (readState.numOpsRemaining--) {
1921         // Read in the next operation. We don't read its regions directly, we
1922         // handle those afterwards as necessary.
1923         bool isIsolatedFromAbove = false;
1924         FailureOr<Operation *> op =
1925             parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
1926         if (failed(op))
1927           return failure();
1928 
1929         // If the op has regions, add it to the stack for processing and return:
1930         // we stop the processing of the current region and resume it after the
1931         // inner one is completed. Unless LazyLoading is activated in which case
1932         // nested region parsing is delayed.
1933         if ((*op)->getNumRegions()) {
1934           RegionReadState childState(*op, &reader, isIsolatedFromAbove);
1935 
1936           // Isolated regions are encoded as a section in version 2 and above.
1937           if (version >= 2 && isIsolatedFromAbove) {
1938             bytecode::Section::ID sectionID;
1939             ArrayRef<uint8_t> sectionData;
1940             if (failed(reader.parseSection(sectionID, sectionData)))
1941               return failure();
1942             if (sectionID != bytecode::Section::kIR)
1943               return emitError(fileLoc, "expected IR section for region");
1944             childState.owningReader =
1945                 std::make_unique<EncodingReader>(sectionData, fileLoc);
1946             childState.reader = childState.owningReader.get();
1947           }
1948 
1949           if (lazyLoading) {
1950             // If the user has a callback set, they have the opportunity
1951             // to control lazyloading as we go.
1952             if (!lazyOpsCallback || !lazyOpsCallback(*op)) {
1953               lazyLoadableOps.push_back(
1954                   std::make_pair(*op, std::move(childState)));
1955               lazyLoadableOpsMap.try_emplace(*op,
1956                                              std::prev(lazyLoadableOps.end()));
1957               continue;
1958             }
1959           }
1960           regionStack.push_back(std::move(childState));
1961 
1962           // If the op is isolated from above, push a new value scope.
1963           if (isIsolatedFromAbove)
1964             valueScopes.emplace_back();
1965           return success();
1966         }
1967       }
1968 
1969       // Move to the next block of the region.
1970       if (++readState.curBlock == readState.curRegion->end())
1971         break;
1972       if (failed(parseBlockHeader(reader, readState)))
1973         return failure();
1974     } while (true);
1975 
1976     // Reset the current block and any values reserved for this region.
1977     readState.curBlock = {};
1978     valueScopes.back().pop(readState);
1979   }
1980 
1981   // When the regions have been fully parsed, pop them off of the read stack. If
1982   // the regions were isolated from above, we also pop the last value scope.
1983   if (readState.isIsolatedFromAbove) {
1984     assert(!valueScopes.empty() && "Expect a valueScope after reading region");
1985     valueScopes.pop_back();
1986   }
1987   assert(!regionStack.empty() && "Expect a regionStack after reading region");
1988   regionStack.pop_back();
1989   return success();
1990 }
1991 
1992 FailureOr<Operation *>
1993 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
1994                                             RegionReadState &readState,
1995                                             bool &isIsolatedFromAbove) {
1996   // Parse the name of the operation.
1997   FailureOr<OperationName> opName = parseOpName(reader);
1998   if (failed(opName))
1999     return failure();
2000 
2001   // Parse the operation mask, which indicates which components of the operation
2002   // are present.
2003   uint8_t opMask;
2004   if (failed(reader.parseByte(opMask)))
2005     return failure();
2006 
2007   /// Parse the location.
2008   LocationAttr opLoc;
2009   if (failed(parseAttribute(reader, opLoc)))
2010     return failure();
2011 
2012   // With the location and name resolved, we can start building the operation
2013   // state.
2014   OperationState opState(opLoc, *opName);
2015 
2016   // Parse the attributes of the operation.
2017   if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
2018     DictionaryAttr dictAttr;
2019     if (failed(parseAttribute(reader, dictAttr)))
2020       return failure();
2021     opState.attributes = dictAttr;
2022   }
2023 
2024   /// Parse the results of the operation.
2025   if (opMask & bytecode::OpEncodingMask::kHasResults) {
2026     uint64_t numResults;
2027     if (failed(reader.parseVarInt(numResults)))
2028       return failure();
2029     opState.types.resize(numResults);
2030     for (int i = 0, e = numResults; i < e; ++i)
2031       if (failed(parseType(reader, opState.types[i])))
2032         return failure();
2033   }
2034 
2035   /// Parse the operands of the operation.
2036   if (opMask & bytecode::OpEncodingMask::kHasOperands) {
2037     uint64_t numOperands;
2038     if (failed(reader.parseVarInt(numOperands)))
2039       return failure();
2040     opState.operands.resize(numOperands);
2041     for (int i = 0, e = numOperands; i < e; ++i)
2042       if (!(opState.operands[i] = parseOperand(reader)))
2043         return failure();
2044   }
2045 
2046   /// Parse the successors of the operation.
2047   if (opMask & bytecode::OpEncodingMask::kHasSuccessors) {
2048     uint64_t numSuccs;
2049     if (failed(reader.parseVarInt(numSuccs)))
2050       return failure();
2051     opState.successors.resize(numSuccs);
2052     for (int i = 0, e = numSuccs; i < e; ++i) {
2053       if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i],
2054                             "successor")))
2055         return failure();
2056     }
2057   }
2058 
2059   /// Parse the use-list orders for the results of the operation. Use-list
2060   /// orders are available since version 3 of the bytecode.
2061   std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2062   if (version > 2 && (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) {
2063     size_t numResults = opState.types.size();
2064     auto parseResult = parseUseListOrderForRange(reader, numResults);
2065     if (failed(parseResult))
2066       return failure();
2067     resultIdxToUseListMap = std::move(*parseResult);
2068   }
2069 
2070   /// Parse the regions of the operation.
2071   if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
2072     uint64_t numRegions;
2073     if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
2074       return failure();
2075 
2076     opState.regions.reserve(numRegions);
2077     for (int i = 0, e = numRegions; i < e; ++i)
2078       opState.regions.push_back(std::make_unique<Region>());
2079   }
2080 
2081   // Create the operation at the back of the current block.
2082   Operation *op = Operation::create(opState);
2083   readState.curBlock->push_back(op);
2084 
2085   // If the operation had results, update the value references.
2086   if (op->getNumResults() && failed(defineValues(reader, op->getResults())))
2087     return failure();
2088 
2089   /// Store a map for every value that received a custom use-list order from the
2090   /// bytecode file.
2091   if (resultIdxToUseListMap.has_value()) {
2092     for (size_t idx = 0; idx < op->getNumResults(); idx++) {
2093       if (resultIdxToUseListMap->contains(idx)) {
2094         valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(),
2095                                       resultIdxToUseListMap->at(idx));
2096       }
2097     }
2098   }
2099   return op;
2100 }
2101 
2102 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2103   EncodingReader &reader = *readState.reader;
2104 
2105   // Parse the number of blocks in the region.
2106   uint64_t numBlocks;
2107   if (failed(reader.parseVarInt(numBlocks)))
2108     return failure();
2109 
2110   // If the region is empty, there is nothing else to do.
2111   if (numBlocks == 0)
2112     return success();
2113 
2114   // Parse the number of values defined in this region.
2115   uint64_t numValues;
2116   if (failed(reader.parseVarInt(numValues)))
2117     return failure();
2118   readState.numValues = numValues;
2119 
2120   // Create the blocks within this region. We do this before processing so that
2121   // we can rely on the blocks existing when creating operations.
2122   readState.curBlocks.clear();
2123   readState.curBlocks.reserve(numBlocks);
2124   for (uint64_t i = 0; i < numBlocks; ++i) {
2125     readState.curBlocks.push_back(new Block());
2126     readState.curRegion->push_back(readState.curBlocks.back());
2127   }
2128 
2129   // Prepare the current value scope for this region.
2130   valueScopes.back().push(readState);
2131 
2132   // Parse the entry block of the region.
2133   readState.curBlock = readState.curRegion->begin();
2134   return parseBlockHeader(reader, readState);
2135 }
2136 
2137 LogicalResult
2138 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2139                                        RegionReadState &readState) {
2140   bool hasArgs;
2141   if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
2142     return failure();
2143 
2144   // Parse the arguments of the block.
2145   if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
2146     return failure();
2147 
2148   // Uselist orders are available since version 3 of the bytecode.
2149   if (version < 3)
2150     return success();
2151 
2152   uint8_t hasUseListOrders = 0;
2153   if (hasArgs && failed(reader.parseByte(hasUseListOrders)))
2154     return failure();
2155 
2156   if (!hasUseListOrders)
2157     return success();
2158 
2159   Block &blk = *readState.curBlock;
2160   auto argIdxToUseListMap =
2161       parseUseListOrderForRange(reader, blk.getNumArguments());
2162   if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
2163     return failure();
2164 
2165   for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
2166     if (argIdxToUseListMap->contains(idx))
2167       valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(),
2168                                     argIdxToUseListMap->at(idx));
2169 
2170   // We don't parse the operations of the block here, that's done elsewhere.
2171   return success();
2172 }
2173 
2174 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
2175                                                         Block *block) {
2176   // Parse the value ID for the first argument, and the number of arguments.
2177   uint64_t numArgs;
2178   if (failed(reader.parseVarInt(numArgs)))
2179     return failure();
2180 
2181   SmallVector<Type> argTypes;
2182   SmallVector<Location> argLocs;
2183   argTypes.reserve(numArgs);
2184   argLocs.reserve(numArgs);
2185 
2186   Location unknownLoc = UnknownLoc::get(config.getContext());
2187   while (numArgs--) {
2188     Type argType;
2189     LocationAttr argLoc = unknownLoc;
2190     if (version > 3) {
2191       // Parse the type with hasLoc flag to determine if it has type.
2192       uint64_t typeIdx;
2193       bool hasLoc;
2194       if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
2195           !(argType = attrTypeReader.resolveType(typeIdx)))
2196         return failure();
2197       if (hasLoc && failed(parseAttribute(reader, argLoc)))
2198         return failure();
2199     } else {
2200       // All args has type and location.
2201       if (failed(parseType(reader, argType)) ||
2202           failed(parseAttribute(reader, argLoc)))
2203         return failure();
2204     }
2205     argTypes.push_back(argType);
2206     argLocs.push_back(argLoc);
2207   }
2208   block->addArguments(argTypes, argLocs);
2209   return defineValues(reader, block->getArguments());
2210 }
2211 
2212 //===----------------------------------------------------------------------===//
2213 // Value Processing
2214 
2215 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2216   std::vector<Value> &values = valueScopes.back().values;
2217   Value *value = nullptr;
2218   if (failed(parseEntry(reader, values, value, "value")))
2219     return Value();
2220 
2221   // Create a new forward reference if necessary.
2222   if (!*value)
2223     *value = createForwardRef();
2224   return *value;
2225 }
2226 
2227 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2228                                                  ValueRange newValues) {
2229   ValueScope &valueScope = valueScopes.back();
2230   std::vector<Value> &values = valueScope.values;
2231 
2232   unsigned &valueID = valueScope.nextValueIDs.back();
2233   unsigned valueIDEnd = valueID + newValues.size();
2234   if (valueIDEnd > values.size()) {
2235     return reader.emitError(
2236         "value index range was outside of the expected range for "
2237         "the parent region, got [",
2238         valueID, ", ", valueIDEnd, "), but the maximum index was ",
2239         values.size() - 1);
2240   }
2241 
2242   // Assign the values and update any forward references.
2243   for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2244     Value newValue = newValues[i];
2245 
2246     // Check to see if a definition for this value already exists.
2247     if (Value oldValue = std::exchange(values[valueID], newValue)) {
2248       Operation *forwardRefOp = oldValue.getDefiningOp();
2249 
2250       // Assert that this is a forward reference operation. Given how we compute
2251       // definition ids (incrementally as we parse), it shouldn't be possible
2252       // for the value to be defined any other way.
2253       assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&
2254              "value index was already defined?");
2255 
2256       oldValue.replaceAllUsesWith(newValue);
2257       forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
2258     }
2259   }
2260   return success();
2261 }
2262 
2263 Value BytecodeReader::Impl::createForwardRef() {
2264   // Check for an avaliable existing operation to use. Otherwise, create a new
2265   // fake operation to use for the reference.
2266   if (!openForwardRefOps.empty()) {
2267     Operation *op = &openForwardRefOps.back();
2268     op->moveBefore(&forwardRefOps, forwardRefOps.end());
2269   } else {
2270     forwardRefOps.push_back(Operation::create(forwardRefOpState));
2271   }
2272   return forwardRefOps.back().getResult(0);
2273 }
2274 
2275 //===----------------------------------------------------------------------===//
2276 // Entry Points
2277 //===----------------------------------------------------------------------===//
2278 
2279 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); }
2280 
2281 BytecodeReader::BytecodeReader(
2282     llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
2283     const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2284   Location sourceFileLoc =
2285       FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2286                           /*line=*/0, /*column=*/0);
2287   impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
2288                                 bufferOwnerRef);
2289 }
2290 
2291 LogicalResult BytecodeReader::readTopLevel(
2292     Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2293   return impl->read(block, lazyOpsCallback);
2294 }
2295 
2296 int64_t BytecodeReader::getNumOpsToMaterialize() const {
2297   return impl->getNumOpsToMaterialize();
2298 }
2299 
2300 bool BytecodeReader::isMaterializable(Operation *op) {
2301   return impl->isMaterializable(op);
2302 }
2303 
2304 LogicalResult BytecodeReader::materialize(
2305     Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2306   return impl->materialize(op, lazyOpsCallback);
2307 }
2308 
2309 LogicalResult
2310 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) {
2311   return impl->finalize(shouldMaterialize);
2312 }
2313 
2314 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
2315   return buffer.getBuffer().startswith("ML\xefR");
2316 }
2317 
2318 /// Read the bytecode from the provided memory buffer reference.
2319 /// `bufferOwnerRef` if provided is the owning source manager for the buffer,
2320 /// and may be used to extend the lifetime of the buffer.
2321 static LogicalResult
2322 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
2323                      const ParserConfig &config,
2324                      const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2325   Location sourceFileLoc =
2326       FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2327                           /*line=*/0, /*column=*/0);
2328   if (!isBytecode(buffer)) {
2329     return emitError(sourceFileLoc,
2330                      "input buffer is not an MLIR bytecode file");
2331   }
2332 
2333   BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
2334                               buffer, bufferOwnerRef);
2335   return reader.read(block, /*lazyOpsCallback=*/nullptr);
2336 }
2337 
2338 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
2339                                      const ParserConfig &config) {
2340   return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{});
2341 }
2342 LogicalResult
2343 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
2344                        Block *block, const ParserConfig &config) {
2345   return readBytecodeFileImpl(
2346       *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config,
2347       sourceMgr);
2348 }
2349