xref: /llvm-project/mlir/include/mlir/Bytecode/BytecodeImplementation.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header defines various interfaces and utilities necessary for dialects
10 // to hook into bytecode serialization.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
15 #define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/DialectInterface.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/Twine.h"
24 
25 namespace mlir {
26 //===--------------------------------------------------------------------===//
27 // Dialect Version Interface.
28 //===--------------------------------------------------------------------===//
29 
/// Opaque base class for a dialect's version information.
///
/// The class declares nothing except a virtual destructor, so that a concrete
/// version object defined by a dialect can be owned and destroyed through a
/// `DialectVersion` pointer.
class DialectVersion {
public:
  /// Virtual so that deleting through a base pointer runs the derived
  /// destructor.
  virtual ~DialectVersion() = default;
};
36 
37 //===----------------------------------------------------------------------===//
38 // DialectBytecodeReader
39 //===----------------------------------------------------------------------===//
40 
41 /// This class defines a virtual interface for reading a bytecode stream,
42 /// providing hooks into the bytecode reader. As such, this class should only be
43 /// derived and defined by the main bytecode reader, users (i.e. dialects)
44 /// should generally only interact with this class via the
45 /// BytecodeDialectInterface below.
46 class DialectBytecodeReader {
47 public:
48   virtual ~DialectBytecodeReader() = default;
49 
50   /// Emit an error to the reader.
51   virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0;
52 
53   /// Retrieve the dialect version by name if available.
54   virtual FailureOr<const DialectVersion *>
55   getDialectVersion(StringRef dialectName) const = 0;
56   template <class T>
getDialectVersion()57   FailureOr<const DialectVersion *> getDialectVersion() const {
58     return getDialectVersion(T::getDialectNamespace());
59   }
60 
61   /// Retrieve the context associated to the reader.
62   virtual MLIRContext *getContext() const = 0;
63 
64   /// Return the bytecode version being read.
65   virtual uint64_t getBytecodeVersion() const = 0;
66 
67   /// Read out a list of elements, invoking the provided callback for each
68   /// element. The callback function may be in any of the following forms:
69   ///   * LogicalResult(T &)
70   ///   * FailureOr<T>()
71   template <typename T, typename CallbackFn>
readList(SmallVectorImpl<T> & result,CallbackFn && callback)72   LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) {
73     uint64_t size;
74     if (failed(readVarInt(size)))
75       return failure();
76     result.reserve(size);
77 
78     for (uint64_t i = 0; i < size; ++i) {
79       // Check if the callback uses FailureOr, or populates the result by
80       // reference.
81       if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
82         T element = {};
83         if (failed(callback(element)))
84           return failure();
85         result.emplace_back(std::move(element));
86       } else {
87         FailureOr<T> element = callback();
88         if (failed(element))
89           return failure();
90         result.emplace_back(std::move(*element));
91       }
92     }
93     return success();
94   }
95 
96   //===--------------------------------------------------------------------===//
97   // IR
98   //===--------------------------------------------------------------------===//
99 
100   /// Read a reference to the given attribute.
101   virtual LogicalResult readAttribute(Attribute &result) = 0;
102   /// Read an optional reference to the given attribute. Returns success even if
103   /// the Attribute isn't present.
104   virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0;
105 
106   template <typename T>
readAttributes(SmallVectorImpl<T> & attrs)107   LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
108     return readList(attrs, [this](T &attr) { return readAttribute(attr); });
109   }
110   template <typename T>
readAttribute(T & result)111   LogicalResult readAttribute(T &result) {
112     Attribute baseResult;
113     if (failed(readAttribute(baseResult)))
114       return failure();
115     if ((result = dyn_cast<T>(baseResult)))
116       return success();
117     return emitError() << "expected " << llvm::getTypeName<T>()
118                        << ", but got: " << baseResult;
119   }
120   template <typename T>
readOptionalAttribute(T & result)121   LogicalResult readOptionalAttribute(T &result) {
122     Attribute baseResult;
123     if (failed(readOptionalAttribute(baseResult)))
124       return failure();
125     if (!baseResult)
126       return success();
127     if ((result = dyn_cast<T>(baseResult)))
128       return success();
129     return emitError() << "expected " << llvm::getTypeName<T>()
130                        << ", but got: " << baseResult;
131   }
132 
133   /// Read a reference to the given type.
134   virtual LogicalResult readType(Type &result) = 0;
135   template <typename T>
readTypes(SmallVectorImpl<T> & types)136   LogicalResult readTypes(SmallVectorImpl<T> &types) {
137     return readList(types, [this](T &type) { return readType(type); });
138   }
139   template <typename T>
readType(T & result)140   LogicalResult readType(T &result) {
141     Type baseResult;
142     if (failed(readType(baseResult)))
143       return failure();
144     if ((result = dyn_cast<T>(baseResult)))
145       return success();
146     return emitError() << "expected " << llvm::getTypeName<T>()
147                        << ", but got: " << baseResult;
148   }
149 
150   /// Read a handle to a dialect resource.
151   template <typename ResourceT>
readResourceHandle()152   FailureOr<ResourceT> readResourceHandle() {
153     FailureOr<AsmDialectResourceHandle> handle = readResourceHandle();
154     if (failed(handle))
155       return failure();
156     if (auto *result = dyn_cast<ResourceT>(&*handle))
157       return std::move(*result);
158     return emitError() << "provided resource handle differs from the "
159                           "expected resource type";
160   }
161 
162   //===--------------------------------------------------------------------===//
163   // Primitives
164   //===--------------------------------------------------------------------===//
165 
166   /// Read a variable width integer.
167   virtual LogicalResult readVarInt(uint64_t &result) = 0;
168 
169   /// Read a signed variable width integer.
170   virtual LogicalResult readSignedVarInt(int64_t &result) = 0;
readSignedVarInts(SmallVectorImpl<int64_t> & result)171   LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) {
172     return readList(result,
173                     [this](int64_t &value) { return readSignedVarInt(value); });
174   }
175 
176   /// Parse a variable length encoded integer whose low bit is used to encode an
177   /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
readVarIntWithFlag(uint64_t & result,bool & flag)178   LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) {
179     if (failed(readVarInt(result)))
180       return failure();
181     flag = result & 1;
182     result >>= 1;
183     return success();
184   }
185 
186   /// Read a "small" sparse array of integer <= 32 bits elements, where
187   /// index/value pairs can be compressed when the array is small.
188   /// Note that only some position of the array will be read and the ones
189   /// not stored in the bytecode are gonne be left untouched.
190   /// If the provided array is too small for the stored indices, an error
191   /// will be returned.
192   template <typename T>
readSparseArray(MutableArrayRef<T> array)193   LogicalResult readSparseArray(MutableArrayRef<T> array) {
194     static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
195     static_assert(std::is_integral<T>::value, "expects integer");
196     uint64_t nonZeroesCount;
197     bool useSparseEncoding;
198     if (failed(readVarIntWithFlag(nonZeroesCount, useSparseEncoding)))
199       return failure();
200     if (nonZeroesCount == 0)
201       return success();
202     if (!useSparseEncoding) {
203       // This is a simple dense array.
204       if (nonZeroesCount > array.size()) {
205         emitError("trying to read an array of ")
206             << nonZeroesCount << " but only " << array.size()
207             << " storage available.";
208         return failure();
209       }
210       for (int64_t index : llvm::seq<int64_t>(0, nonZeroesCount)) {
211         uint64_t value;
212         if (failed(readVarInt(value)))
213           return failure();
214         array[index] = value;
215       }
216       return success();
217     }
218     // Read sparse encoding
219     // This is the number of bits used for packing the index with the value.
220     uint64_t indexBitSize;
221     if (failed(readVarInt(indexBitSize)))
222       return failure();
223     constexpr uint64_t maxIndexBitSize = 8;
224     if (indexBitSize > maxIndexBitSize) {
225       emitError("reading sparse array with indexing above 8 bits: ")
226           << indexBitSize;
227       return failure();
228     }
229     for (uint32_t count : llvm::seq<uint32_t>(0, nonZeroesCount)) {
230       (void)count;
231       uint64_t indexValuePair;
232       if (failed(readVarInt(indexValuePair)))
233         return failure();
234       uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));
235       uint64_t value = indexValuePair >> indexBitSize;
236       if (index >= array.size()) {
237         emitError("reading a sparse array found index ")
238             << index << " but only " << array.size() << " storage available.";
239         return failure();
240       }
241       array[index] = value;
242     }
243     return success();
244   }
245 
246   /// Read an APInt that is known to have been encoded with the given width.
247   virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
248 
249   /// Read an APFloat that is known to have been encoded with the given
250   /// semantics.
251   virtual FailureOr<APFloat>
252   readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) = 0;
253 
254   /// Read a string from the bytecode.
255   virtual LogicalResult readString(StringRef &result) = 0;
256 
257   /// Read a blob from the bytecode.
258   virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
259 
260   /// Read a bool from the bytecode.
261   virtual LogicalResult readBool(bool &result) = 0;
262 
263 private:
264   /// Read a handle to a dialect resource.
265   virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
266 };
267 
268 //===----------------------------------------------------------------------===//
269 // DialectBytecodeWriter
270 //===----------------------------------------------------------------------===//
271 
272 /// This class defines a virtual interface for writing to a bytecode stream,
273 /// providing hooks into the bytecode writer. As such, this class should only be
274 /// derived and defined by the main bytecode writer, users (i.e. dialects)
275 /// should generally only interact with this class via the
276 /// BytecodeDialectInterface below.
277 class DialectBytecodeWriter {
278 public:
279   virtual ~DialectBytecodeWriter() = default;
280 
281   //===--------------------------------------------------------------------===//
282   // IR
283   //===--------------------------------------------------------------------===//
284 
285   /// Write out a list of elements, invoking the provided callback for each
286   /// element.
287   template <typename RangeT, typename CallbackFn>
writeList(RangeT && range,CallbackFn && callback)288   void writeList(RangeT &&range, CallbackFn &&callback) {
289     writeVarInt(llvm::size(range));
290     for (auto &element : range)
291       callback(element);
292   }
293 
294   /// Write a reference to the given attribute.
295   virtual void writeAttribute(Attribute attr) = 0;
296   virtual void writeOptionalAttribute(Attribute attr) = 0;
297   template <typename T>
writeAttributes(ArrayRef<T> attrs)298   void writeAttributes(ArrayRef<T> attrs) {
299     writeList(attrs, [this](T attr) { writeAttribute(attr); });
300   }
301 
302   /// Write a reference to the given type.
303   virtual void writeType(Type type) = 0;
304   template <typename T>
writeTypes(ArrayRef<T> types)305   void writeTypes(ArrayRef<T> types) {
306     writeList(types, [this](T type) { writeType(type); });
307   }
308 
309   /// Write the given handle to a dialect resource.
310   virtual void
311   writeResourceHandle(const AsmDialectResourceHandle &resource) = 0;
312 
313   //===--------------------------------------------------------------------===//
314   // Primitives
315   //===--------------------------------------------------------------------===//
316 
317   /// Write a variable width integer to the output stream. This should be the
318   /// preferred method for emitting integers whenever possible.
319   virtual void writeVarInt(uint64_t value) = 0;
320 
321   /// Write a signed variable width integer to the output stream. This should be
322   /// the preferred method for emitting signed integers whenever possible.
323   virtual void writeSignedVarInt(int64_t value) = 0;
writeSignedVarInts(ArrayRef<int64_t> value)324   void writeSignedVarInts(ArrayRef<int64_t> value) {
325     writeList(value, [this](int64_t value) { writeSignedVarInt(value); });
326   }
327 
328   /// Write a VarInt and a flag packed together.
writeVarIntWithFlag(uint64_t value,bool flag)329   void writeVarIntWithFlag(uint64_t value, bool flag) {
330     writeVarInt((value << 1) | (flag ? 1 : 0));
331   }
332 
333   /// Write out a "small" sparse array of integer <= 32 bits elements, where
334   /// index/value pairs can be compressed when the array is small. This method
335   /// will scan the array multiple times and should not be used for large
336   /// arrays. The optional provided "zero" can be used to adjust for the
337   /// expected repeated value. We assume here that the array size fits in a 32
338   /// bits integer.
339   template <typename T>
writeSparseArray(ArrayRef<T> array)340   void writeSparseArray(ArrayRef<T> array) {
341     static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
342     static_assert(std::is_integral<T>::value, "expects integer");
343     uint32_t size = array.size();
344     uint32_t nonZeroesCount = 0, lastIndex = 0;
345     for (uint32_t index : llvm::seq<uint32_t>(0, size)) {
346       if (!array[index])
347         continue;
348       nonZeroesCount++;
349       lastIndex = index;
350     }
351     // If the last position is too large, or the array isn't at least 50%
352     // sparse, emit it with a dense encoding.
353     if (lastIndex > 256 || nonZeroesCount > size / 2) {
354       // Emit the array size and a flag which indicates whether it is sparse.
355       writeVarIntWithFlag(size, false);
356       for (const T &elt : array)
357         writeVarInt(elt);
358       return;
359     }
360     // Emit sparse: first the number of elements we'll write and a flag
361     // indicating it is a sparse encoding.
362     writeVarIntWithFlag(nonZeroesCount, true);
363     if (nonZeroesCount == 0)
364       return;
365     // This is the number of bits used for packing the index with the value.
366     int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1);
367     writeVarInt(indexBitSize);
368     for (uint32_t index : llvm::seq<uint32_t>(0, lastIndex + 1)) {
369       T value = array[index];
370       if (!value)
371         continue;
372       uint64_t indexValuePair = (value << indexBitSize) | (index);
373       writeVarInt(indexValuePair);
374     }
375   }
376 
377   /// Write an APInt to the bytecode stream whose bitwidth will be known
378   /// externally at read time. This method is useful for encoding APInt values
379   /// when the width is known via external means, such as via a type. This
380   /// method should generally only be invoked if you need an APInt, otherwise
381   /// use the varint methods above. APInt values are generally encoded using
382   /// zigzag encoding, to enable more efficient encodings for negative values.
383   virtual void writeAPIntWithKnownWidth(const APInt &value) = 0;
384 
385   /// Write an APFloat to the bytecode stream whose semantics will be known
386   /// externally at read time. This method is useful for encoding APFloat values
387   /// when the semantics are known via external means, such as via a type.
388   virtual void writeAPFloatWithKnownSemantics(const APFloat &value) = 0;
389 
390   /// Write a string to the bytecode, which is owned by the caller and is
391   /// guaranteed to not die before the end of the bytecode process. This should
392   /// only be called if such a guarantee can be made, such as when the string is
393   /// owned by an attribute or type.
394   virtual void writeOwnedString(StringRef str) = 0;
395 
396   /// Write a blob to the bytecode, which is owned by the caller and is
397   /// guaranteed to not die before the end of the bytecode process. The blob is
398   /// written as-is, with no additional compression or compaction.
399   virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
400 
401   /// Write a bool to the output stream.
402   virtual void writeOwnedBool(bool value) = 0;
403 
404   /// Return the bytecode version being emitted for.
405   virtual int64_t getBytecodeVersion() const = 0;
406 
407   /// Retrieve the dialect version by name if available.
408   virtual FailureOr<const DialectVersion *>
409   getDialectVersion(StringRef dialectName) const = 0;
410 
411   template <class T>
getDialectVersion()412   FailureOr<const DialectVersion *> getDialectVersion() const {
413     return getDialectVersion(T::getDialectNamespace());
414   }
415 };
416 
417 //===----------------------------------------------------------------------===//
418 // BytecodeDialectInterface
419 //===----------------------------------------------------------------------===//
420 
421 class BytecodeDialectInterface
422     : public DialectInterface::Base<BytecodeDialectInterface> {
423 public:
424   using Base::Base;
425 
426   //===--------------------------------------------------------------------===//
427   // Reading
428   //===--------------------------------------------------------------------===//
429 
430   /// Read an attribute belonging to this dialect from the given reader. This
431   /// method should return null in the case of failure. Optionally, the dialect
432   /// version can be accessed through the reader.
readAttribute(DialectBytecodeReader & reader)433   virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
434     reader.emitError() << "dialect " << getDialect()->getNamespace()
435                        << " does not support reading attributes from bytecode";
436     return Attribute();
437   }
438 
439   /// Read a type belonging to this dialect from the given reader. This method
440   /// should return null in the case of failure. Optionally, the dialect version
441   /// can be accessed thorugh the reader.
readType(DialectBytecodeReader & reader)442   virtual Type readType(DialectBytecodeReader &reader) const {
443     reader.emitError() << "dialect " << getDialect()->getNamespace()
444                        << " does not support reading types from bytecode";
445     return Type();
446   }
447 
448   //===--------------------------------------------------------------------===//
449   // Writing
450   //===--------------------------------------------------------------------===//
451 
452   /// Write the given attribute, which belongs to this dialect, to the given
453   /// writer. This method may return failure to indicate that the given
454   /// attribute could not be encoded, in which case the textual format will be
455   /// used to encode this attribute instead.
writeAttribute(Attribute attr,DialectBytecodeWriter & writer)456   virtual LogicalResult writeAttribute(Attribute attr,
457                                        DialectBytecodeWriter &writer) const {
458     return failure();
459   }
460 
461   /// Write the given type, which belongs to this dialect, to the given writer.
462   /// This method may return failure to indicate that the given type could not
463   /// be encoded, in which case the textual format will be used to encode this
464   /// type instead.
writeType(Type type,DialectBytecodeWriter & writer)465   virtual LogicalResult writeType(Type type,
466                                   DialectBytecodeWriter &writer) const {
467     return failure();
468   }
469 
470   /// Write the version of this dialect to the given writer.
writeVersion(DialectBytecodeWriter & writer)471   virtual void writeVersion(DialectBytecodeWriter &writer) const {}
472 
473   // Read the version of this dialect from the provided reader and return it as
474   // a `unique_ptr` to a dialect version object.
475   virtual std::unique_ptr<DialectVersion>
readVersion(DialectBytecodeReader & reader)476   readVersion(DialectBytecodeReader &reader) const {
477     reader.emitError("Dialect does not support versioning");
478     return nullptr;
479   }
480 
481   /// Hook invoked after parsing completed, if a version directive was present
482   /// and included an entry for the current dialect. This hook offers the
483   /// opportunity to the dialect to visit the IR and upgrades constructs emitted
484   /// by the version of the dialect corresponding to the provided version.
485   virtual LogicalResult
upgradeFromVersion(Operation * topLevelOp,const DialectVersion & version)486   upgradeFromVersion(Operation *topLevelOp,
487                      const DialectVersion &version) const {
488     return success();
489   }
490 };
491 
492 /// Helper for resource handle reading that returns LogicalResult.
493 template <typename T, typename... Ts>
readResourceHandle(DialectBytecodeReader & reader,FailureOr<T> & value,Ts &&...params)494 static LogicalResult readResourceHandle(DialectBytecodeReader &reader,
495                                         FailureOr<T> &value, Ts &&...params) {
496   FailureOr<T> handle = reader.readResourceHandle<T>();
497   if (failed(handle))
498     return failure();
499   if (auto *result = dyn_cast<T>(&*handle)) {
500     value = std::move(*result);
501     return success();
502   }
503   return failure();
504 }
505 
506 /// Helper method that injects context only if needed, this helps unify some of
507 /// the attribute construction methods.
508 template <typename T, typename... Ts>
get(MLIRContext * context,Ts &&...params)509 auto get(MLIRContext *context, Ts &&...params) {
510   // Prefer a direct `get` method if one exists.
511   if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
512     (void)context;
513     return T::get(std::forward<Ts>(params)...);
514   } else if constexpr (llvm::is_detected<detail::has_get_method, T,
515                                          MLIRContext *, Ts...>::value) {
516     return T::get(context, std::forward<Ts>(params)...);
517   } else {
518     // Otherwise, pass to the base get.
519     return T::Base::get(context, std::forward<Ts>(params)...);
520   }
521 }
522 
523 } // namespace mlir
524 
525 #endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
526