1 //===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header defines various interfaces and utilities necessary for dialects
10 // to hook into bytecode serialization.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
15 #define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
16
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/DialectInterface.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/Twine.h"
24
25 namespace mlir {
26 //===--------------------------------------------------------------------===//
27 // Dialect Version Interface.
28 //===--------------------------------------------------------------------===//
29
/// Opaque base type for a dialect's version information. It declares nothing
/// beyond a virtual destructor, so that objects of derived version types can
/// be destroyed correctly through a `DialectVersion` pointer.
class DialectVersion {
public:
  virtual ~DialectVersion() = default;
};
36
37 //===----------------------------------------------------------------------===//
38 // DialectBytecodeReader
39 //===----------------------------------------------------------------------===//
40
41 /// This class defines a virtual interface for reading a bytecode stream,
42 /// providing hooks into the bytecode reader. As such, this class should only be
43 /// derived and defined by the main bytecode reader, users (i.e. dialects)
44 /// should generally only interact with this class via the
45 /// BytecodeDialectInterface below.
46 class DialectBytecodeReader {
47 public:
48 virtual ~DialectBytecodeReader() = default;
49
50 /// Emit an error to the reader.
51 virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0;
52
53 /// Retrieve the dialect version by name if available.
54 virtual FailureOr<const DialectVersion *>
55 getDialectVersion(StringRef dialectName) const = 0;
56 template <class T>
getDialectVersion()57 FailureOr<const DialectVersion *> getDialectVersion() const {
58 return getDialectVersion(T::getDialectNamespace());
59 }
60
61 /// Retrieve the context associated to the reader.
62 virtual MLIRContext *getContext() const = 0;
63
64 /// Return the bytecode version being read.
65 virtual uint64_t getBytecodeVersion() const = 0;
66
67 /// Read out a list of elements, invoking the provided callback for each
68 /// element. The callback function may be in any of the following forms:
69 /// * LogicalResult(T &)
70 /// * FailureOr<T>()
71 template <typename T, typename CallbackFn>
readList(SmallVectorImpl<T> & result,CallbackFn && callback)72 LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) {
73 uint64_t size;
74 if (failed(readVarInt(size)))
75 return failure();
76 result.reserve(size);
77
78 for (uint64_t i = 0; i < size; ++i) {
79 // Check if the callback uses FailureOr, or populates the result by
80 // reference.
81 if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
82 T element = {};
83 if (failed(callback(element)))
84 return failure();
85 result.emplace_back(std::move(element));
86 } else {
87 FailureOr<T> element = callback();
88 if (failed(element))
89 return failure();
90 result.emplace_back(std::move(*element));
91 }
92 }
93 return success();
94 }
95
96 //===--------------------------------------------------------------------===//
97 // IR
98 //===--------------------------------------------------------------------===//
99
100 /// Read a reference to the given attribute.
101 virtual LogicalResult readAttribute(Attribute &result) = 0;
102 /// Read an optional reference to the given attribute. Returns success even if
103 /// the Attribute isn't present.
104 virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0;
105
106 template <typename T>
readAttributes(SmallVectorImpl<T> & attrs)107 LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
108 return readList(attrs, [this](T &attr) { return readAttribute(attr); });
109 }
110 template <typename T>
readAttribute(T & result)111 LogicalResult readAttribute(T &result) {
112 Attribute baseResult;
113 if (failed(readAttribute(baseResult)))
114 return failure();
115 if ((result = dyn_cast<T>(baseResult)))
116 return success();
117 return emitError() << "expected " << llvm::getTypeName<T>()
118 << ", but got: " << baseResult;
119 }
120 template <typename T>
readOptionalAttribute(T & result)121 LogicalResult readOptionalAttribute(T &result) {
122 Attribute baseResult;
123 if (failed(readOptionalAttribute(baseResult)))
124 return failure();
125 if (!baseResult)
126 return success();
127 if ((result = dyn_cast<T>(baseResult)))
128 return success();
129 return emitError() << "expected " << llvm::getTypeName<T>()
130 << ", but got: " << baseResult;
131 }
132
133 /// Read a reference to the given type.
134 virtual LogicalResult readType(Type &result) = 0;
135 template <typename T>
readTypes(SmallVectorImpl<T> & types)136 LogicalResult readTypes(SmallVectorImpl<T> &types) {
137 return readList(types, [this](T &type) { return readType(type); });
138 }
139 template <typename T>
readType(T & result)140 LogicalResult readType(T &result) {
141 Type baseResult;
142 if (failed(readType(baseResult)))
143 return failure();
144 if ((result = dyn_cast<T>(baseResult)))
145 return success();
146 return emitError() << "expected " << llvm::getTypeName<T>()
147 << ", but got: " << baseResult;
148 }
149
150 /// Read a handle to a dialect resource.
151 template <typename ResourceT>
readResourceHandle()152 FailureOr<ResourceT> readResourceHandle() {
153 FailureOr<AsmDialectResourceHandle> handle = readResourceHandle();
154 if (failed(handle))
155 return failure();
156 if (auto *result = dyn_cast<ResourceT>(&*handle))
157 return std::move(*result);
158 return emitError() << "provided resource handle differs from the "
159 "expected resource type";
160 }
161
162 //===--------------------------------------------------------------------===//
163 // Primitives
164 //===--------------------------------------------------------------------===//
165
166 /// Read a variable width integer.
167 virtual LogicalResult readVarInt(uint64_t &result) = 0;
168
169 /// Read a signed variable width integer.
170 virtual LogicalResult readSignedVarInt(int64_t &result) = 0;
readSignedVarInts(SmallVectorImpl<int64_t> & result)171 LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) {
172 return readList(result,
173 [this](int64_t &value) { return readSignedVarInt(value); });
174 }
175
176 /// Parse a variable length encoded integer whose low bit is used to encode an
177 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
readVarIntWithFlag(uint64_t & result,bool & flag)178 LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) {
179 if (failed(readVarInt(result)))
180 return failure();
181 flag = result & 1;
182 result >>= 1;
183 return success();
184 }
185
186 /// Read a "small" sparse array of integer <= 32 bits elements, where
187 /// index/value pairs can be compressed when the array is small.
188 /// Note that only some position of the array will be read and the ones
189 /// not stored in the bytecode are gonne be left untouched.
190 /// If the provided array is too small for the stored indices, an error
191 /// will be returned.
192 template <typename T>
readSparseArray(MutableArrayRef<T> array)193 LogicalResult readSparseArray(MutableArrayRef<T> array) {
194 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
195 static_assert(std::is_integral<T>::value, "expects integer");
196 uint64_t nonZeroesCount;
197 bool useSparseEncoding;
198 if (failed(readVarIntWithFlag(nonZeroesCount, useSparseEncoding)))
199 return failure();
200 if (nonZeroesCount == 0)
201 return success();
202 if (!useSparseEncoding) {
203 // This is a simple dense array.
204 if (nonZeroesCount > array.size()) {
205 emitError("trying to read an array of ")
206 << nonZeroesCount << " but only " << array.size()
207 << " storage available.";
208 return failure();
209 }
210 for (int64_t index : llvm::seq<int64_t>(0, nonZeroesCount)) {
211 uint64_t value;
212 if (failed(readVarInt(value)))
213 return failure();
214 array[index] = value;
215 }
216 return success();
217 }
218 // Read sparse encoding
219 // This is the number of bits used for packing the index with the value.
220 uint64_t indexBitSize;
221 if (failed(readVarInt(indexBitSize)))
222 return failure();
223 constexpr uint64_t maxIndexBitSize = 8;
224 if (indexBitSize > maxIndexBitSize) {
225 emitError("reading sparse array with indexing above 8 bits: ")
226 << indexBitSize;
227 return failure();
228 }
229 for (uint32_t count : llvm::seq<uint32_t>(0, nonZeroesCount)) {
230 (void)count;
231 uint64_t indexValuePair;
232 if (failed(readVarInt(indexValuePair)))
233 return failure();
234 uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));
235 uint64_t value = indexValuePair >> indexBitSize;
236 if (index >= array.size()) {
237 emitError("reading a sparse array found index ")
238 << index << " but only " << array.size() << " storage available.";
239 return failure();
240 }
241 array[index] = value;
242 }
243 return success();
244 }
245
246 /// Read an APInt that is known to have been encoded with the given width.
247 virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
248
249 /// Read an APFloat that is known to have been encoded with the given
250 /// semantics.
251 virtual FailureOr<APFloat>
252 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) = 0;
253
254 /// Read a string from the bytecode.
255 virtual LogicalResult readString(StringRef &result) = 0;
256
257 /// Read a blob from the bytecode.
258 virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
259
260 /// Read a bool from the bytecode.
261 virtual LogicalResult readBool(bool &result) = 0;
262
263 private:
264 /// Read a handle to a dialect resource.
265 virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
266 };
267
268 //===----------------------------------------------------------------------===//
269 // DialectBytecodeWriter
270 //===----------------------------------------------------------------------===//
271
272 /// This class defines a virtual interface for writing to a bytecode stream,
273 /// providing hooks into the bytecode writer. As such, this class should only be
274 /// derived and defined by the main bytecode writer, users (i.e. dialects)
275 /// should generally only interact with this class via the
276 /// BytecodeDialectInterface below.
277 class DialectBytecodeWriter {
278 public:
279 virtual ~DialectBytecodeWriter() = default;
280
281 //===--------------------------------------------------------------------===//
282 // IR
283 //===--------------------------------------------------------------------===//
284
285 /// Write out a list of elements, invoking the provided callback for each
286 /// element.
287 template <typename RangeT, typename CallbackFn>
writeList(RangeT && range,CallbackFn && callback)288 void writeList(RangeT &&range, CallbackFn &&callback) {
289 writeVarInt(llvm::size(range));
290 for (auto &element : range)
291 callback(element);
292 }
293
294 /// Write a reference to the given attribute.
295 virtual void writeAttribute(Attribute attr) = 0;
296 virtual void writeOptionalAttribute(Attribute attr) = 0;
297 template <typename T>
writeAttributes(ArrayRef<T> attrs)298 void writeAttributes(ArrayRef<T> attrs) {
299 writeList(attrs, [this](T attr) { writeAttribute(attr); });
300 }
301
302 /// Write a reference to the given type.
303 virtual void writeType(Type type) = 0;
304 template <typename T>
writeTypes(ArrayRef<T> types)305 void writeTypes(ArrayRef<T> types) {
306 writeList(types, [this](T type) { writeType(type); });
307 }
308
309 /// Write the given handle to a dialect resource.
310 virtual void
311 writeResourceHandle(const AsmDialectResourceHandle &resource) = 0;
312
313 //===--------------------------------------------------------------------===//
314 // Primitives
315 //===--------------------------------------------------------------------===//
316
317 /// Write a variable width integer to the output stream. This should be the
318 /// preferred method for emitting integers whenever possible.
319 virtual void writeVarInt(uint64_t value) = 0;
320
321 /// Write a signed variable width integer to the output stream. This should be
322 /// the preferred method for emitting signed integers whenever possible.
323 virtual void writeSignedVarInt(int64_t value) = 0;
writeSignedVarInts(ArrayRef<int64_t> value)324 void writeSignedVarInts(ArrayRef<int64_t> value) {
325 writeList(value, [this](int64_t value) { writeSignedVarInt(value); });
326 }
327
328 /// Write a VarInt and a flag packed together.
writeVarIntWithFlag(uint64_t value,bool flag)329 void writeVarIntWithFlag(uint64_t value, bool flag) {
330 writeVarInt((value << 1) | (flag ? 1 : 0));
331 }
332
333 /// Write out a "small" sparse array of integer <= 32 bits elements, where
334 /// index/value pairs can be compressed when the array is small. This method
335 /// will scan the array multiple times and should not be used for large
336 /// arrays. The optional provided "zero" can be used to adjust for the
337 /// expected repeated value. We assume here that the array size fits in a 32
338 /// bits integer.
339 template <typename T>
writeSparseArray(ArrayRef<T> array)340 void writeSparseArray(ArrayRef<T> array) {
341 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
342 static_assert(std::is_integral<T>::value, "expects integer");
343 uint32_t size = array.size();
344 uint32_t nonZeroesCount = 0, lastIndex = 0;
345 for (uint32_t index : llvm::seq<uint32_t>(0, size)) {
346 if (!array[index])
347 continue;
348 nonZeroesCount++;
349 lastIndex = index;
350 }
351 // If the last position is too large, or the array isn't at least 50%
352 // sparse, emit it with a dense encoding.
353 if (lastIndex > 256 || nonZeroesCount > size / 2) {
354 // Emit the array size and a flag which indicates whether it is sparse.
355 writeVarIntWithFlag(size, false);
356 for (const T &elt : array)
357 writeVarInt(elt);
358 return;
359 }
360 // Emit sparse: first the number of elements we'll write and a flag
361 // indicating it is a sparse encoding.
362 writeVarIntWithFlag(nonZeroesCount, true);
363 if (nonZeroesCount == 0)
364 return;
365 // This is the number of bits used for packing the index with the value.
366 int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1);
367 writeVarInt(indexBitSize);
368 for (uint32_t index : llvm::seq<uint32_t>(0, lastIndex + 1)) {
369 T value = array[index];
370 if (!value)
371 continue;
372 uint64_t indexValuePair = (value << indexBitSize) | (index);
373 writeVarInt(indexValuePair);
374 }
375 }
376
377 /// Write an APInt to the bytecode stream whose bitwidth will be known
378 /// externally at read time. This method is useful for encoding APInt values
379 /// when the width is known via external means, such as via a type. This
380 /// method should generally only be invoked if you need an APInt, otherwise
381 /// use the varint methods above. APInt values are generally encoded using
382 /// zigzag encoding, to enable more efficient encodings for negative values.
383 virtual void writeAPIntWithKnownWidth(const APInt &value) = 0;
384
385 /// Write an APFloat to the bytecode stream whose semantics will be known
386 /// externally at read time. This method is useful for encoding APFloat values
387 /// when the semantics are known via external means, such as via a type.
388 virtual void writeAPFloatWithKnownSemantics(const APFloat &value) = 0;
389
390 /// Write a string to the bytecode, which is owned by the caller and is
391 /// guaranteed to not die before the end of the bytecode process. This should
392 /// only be called if such a guarantee can be made, such as when the string is
393 /// owned by an attribute or type.
394 virtual void writeOwnedString(StringRef str) = 0;
395
396 /// Write a blob to the bytecode, which is owned by the caller and is
397 /// guaranteed to not die before the end of the bytecode process. The blob is
398 /// written as-is, with no additional compression or compaction.
399 virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
400
401 /// Write a bool to the output stream.
402 virtual void writeOwnedBool(bool value) = 0;
403
404 /// Return the bytecode version being emitted for.
405 virtual int64_t getBytecodeVersion() const = 0;
406
407 /// Retrieve the dialect version by name if available.
408 virtual FailureOr<const DialectVersion *>
409 getDialectVersion(StringRef dialectName) const = 0;
410
411 template <class T>
getDialectVersion()412 FailureOr<const DialectVersion *> getDialectVersion() const {
413 return getDialectVersion(T::getDialectNamespace());
414 }
415 };
416
417 //===----------------------------------------------------------------------===//
418 // BytecodeDialectInterface
419 //===----------------------------------------------------------------------===//
420
421 class BytecodeDialectInterface
422 : public DialectInterface::Base<BytecodeDialectInterface> {
423 public:
424 using Base::Base;
425
426 //===--------------------------------------------------------------------===//
427 // Reading
428 //===--------------------------------------------------------------------===//
429
430 /// Read an attribute belonging to this dialect from the given reader. This
431 /// method should return null in the case of failure. Optionally, the dialect
432 /// version can be accessed through the reader.
readAttribute(DialectBytecodeReader & reader)433 virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
434 reader.emitError() << "dialect " << getDialect()->getNamespace()
435 << " does not support reading attributes from bytecode";
436 return Attribute();
437 }
438
439 /// Read a type belonging to this dialect from the given reader. This method
440 /// should return null in the case of failure. Optionally, the dialect version
441 /// can be accessed thorugh the reader.
readType(DialectBytecodeReader & reader)442 virtual Type readType(DialectBytecodeReader &reader) const {
443 reader.emitError() << "dialect " << getDialect()->getNamespace()
444 << " does not support reading types from bytecode";
445 return Type();
446 }
447
448 //===--------------------------------------------------------------------===//
449 // Writing
450 //===--------------------------------------------------------------------===//
451
452 /// Write the given attribute, which belongs to this dialect, to the given
453 /// writer. This method may return failure to indicate that the given
454 /// attribute could not be encoded, in which case the textual format will be
455 /// used to encode this attribute instead.
writeAttribute(Attribute attr,DialectBytecodeWriter & writer)456 virtual LogicalResult writeAttribute(Attribute attr,
457 DialectBytecodeWriter &writer) const {
458 return failure();
459 }
460
461 /// Write the given type, which belongs to this dialect, to the given writer.
462 /// This method may return failure to indicate that the given type could not
463 /// be encoded, in which case the textual format will be used to encode this
464 /// type instead.
writeType(Type type,DialectBytecodeWriter & writer)465 virtual LogicalResult writeType(Type type,
466 DialectBytecodeWriter &writer) const {
467 return failure();
468 }
469
470 /// Write the version of this dialect to the given writer.
writeVersion(DialectBytecodeWriter & writer)471 virtual void writeVersion(DialectBytecodeWriter &writer) const {}
472
473 // Read the version of this dialect from the provided reader and return it as
474 // a `unique_ptr` to a dialect version object.
475 virtual std::unique_ptr<DialectVersion>
readVersion(DialectBytecodeReader & reader)476 readVersion(DialectBytecodeReader &reader) const {
477 reader.emitError("Dialect does not support versioning");
478 return nullptr;
479 }
480
481 /// Hook invoked after parsing completed, if a version directive was present
482 /// and included an entry for the current dialect. This hook offers the
483 /// opportunity to the dialect to visit the IR and upgrades constructs emitted
484 /// by the version of the dialect corresponding to the provided version.
485 virtual LogicalResult
upgradeFromVersion(Operation * topLevelOp,const DialectVersion & version)486 upgradeFromVersion(Operation *topLevelOp,
487 const DialectVersion &version) const {
488 return success();
489 }
490 };
491
492 /// Helper for resource handle reading that returns LogicalResult.
493 template <typename T, typename... Ts>
readResourceHandle(DialectBytecodeReader & reader,FailureOr<T> & value,Ts &&...params)494 static LogicalResult readResourceHandle(DialectBytecodeReader &reader,
495 FailureOr<T> &value, Ts &&...params) {
496 FailureOr<T> handle = reader.readResourceHandle<T>();
497 if (failed(handle))
498 return failure();
499 if (auto *result = dyn_cast<T>(&*handle)) {
500 value = std::move(*result);
501 return success();
502 }
503 return failure();
504 }
505
506 /// Helper method that injects context only if needed, this helps unify some of
507 /// the attribute construction methods.
508 template <typename T, typename... Ts>
get(MLIRContext * context,Ts &&...params)509 auto get(MLIRContext *context, Ts &&...params) {
510 // Prefer a direct `get` method if one exists.
511 if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
512 (void)context;
513 return T::get(std::forward<Ts>(params)...);
514 } else if constexpr (llvm::is_detected<detail::has_get_method, T,
515 MLIRContext *, Ts...>::value) {
516 return T::get(context, std::forward<Ts>(params)...);
517 } else {
518 // Otherwise, pass to the base get.
519 return T::Base::get(context, std::forward<Ts>(params)...);
520 }
521 }
522
523 } // namespace mlir
524
525 #endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
526