xref: /llvm-project/mlir/include/mlir/Bytecode/BytecodeWriter.h (revision 65bc259a97cd8cc70907b65f59aff728245ba9c0)
1 //===- BytecodeWriter.h - MLIR Bytecode Writer ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header defines interfaces to write MLIR bytecode files/streams.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_BYTECODE_BYTECODEWRITER_H
14 #define MLIR_BYTECODE_BYTECODEWRITER_H
15 
16 #include "mlir/IR/AsmState.h"
17 #include "llvm/Config/llvm-config.h" // for LLVM_VERSION_STRING
18 
19 namespace mlir {
// Forward declarations: this header only uses these types by reference, by
// pointer, or as a template argument of an owning smart pointer.
class DialectBytecodeWriter;
class DialectVersion;
class Operation;
23 
24 /// A class to interact with the attributes and types printer when emitting MLIR
25 /// bytecode.
26 template <class T>
27 class AttrTypeBytecodeWriter {
28 public:
29   AttrTypeBytecodeWriter() = default;
30   virtual ~AttrTypeBytecodeWriter() = default;
31 
32   /// Callback writer API used in IRNumbering, where groups are created and
33   /// type/attribute components are numbered. At this stage, writer is expected
34   /// to be a `NumberingDialectWriter`.
35   virtual LogicalResult write(T entry, std::optional<StringRef> &name,
36                               DialectBytecodeWriter &writer) = 0;
37 
38   /// Callback writer API used in BytecodeWriter, where groups are created and
39   /// type/attribute components are numbered. Here, DialectBytecodeWriter is
40   /// expected to be an actual writer. The optional stringref specified by
41   /// the user is ignored, since the group was already specified when numbering
42   /// the IR.
43   LogicalResult write(T entry, DialectBytecodeWriter &writer) {
44     std::optional<StringRef> dummy;
45     return write(entry, dummy, writer);
46   }
47 
48   /// Return an Attribute/Type printer implemented via the given callable, whose
49   /// form should match that of the `write` function above.
50   template <typename CallableT,
51             std::enable_if_t<std::is_convertible_v<
52                                  CallableT, std::function<LogicalResult(
53                                                 T, std::optional<StringRef> &,
54                                                 DialectBytecodeWriter &)>>,
55                              bool> = true>
56   static std::unique_ptr<AttrTypeBytecodeWriter<T>>
57   fromCallable(CallableT &&writeFn) {
58     struct Processor : public AttrTypeBytecodeWriter<T> {
59       Processor(CallableT &&writeFn)
60           : AttrTypeBytecodeWriter(), writeFn(std::move(writeFn)) {}
61       LogicalResult write(T entry, std::optional<StringRef> &name,
62                           DialectBytecodeWriter &writer) override {
63         return writeFn(entry, name, writer);
64       }
65 
66       std::decay_t<CallableT> writeFn;
67     };
68     return std::make_unique<Processor>(std::forward<CallableT>(writeFn));
69   }
70 };
71 
72 /// This class contains the configuration used for the bytecode writer. It
73 /// controls various aspects of bytecode generation, and contains all of the
74 /// various bytecode writer hooks.
75 class BytecodeWriterConfig {
76 public:
77   /// `producer` is an optional string that can be used to identify the producer
78   /// of the bytecode when reading. It has no functional effect on the bytecode
79   /// serialization.
80   BytecodeWriterConfig(StringRef producer = "MLIR" LLVM_VERSION_STRING);
81   /// `map` is a fallback resource map, which when provided will attach resource
82   /// printers for the fallback resources within the map.
83   BytecodeWriterConfig(FallbackAsmResourceMap &map,
84                        StringRef producer = "MLIR" LLVM_VERSION_STRING);
85   ~BytecodeWriterConfig();
86 
87   /// An internal implementation class that contains the state of the
88   /// configuration.
89   struct Impl;
90 
91   /// Return an instance of the internal implementation.
92   const Impl &getImpl() const { return *impl; }
93 
94   /// Set the desired bytecode version to emit. This method does not validate
95   /// the desired version. The bytecode writer entry point will return failure
96   /// if it cannot emit the desired version.
97   void setDesiredBytecodeVersion(int64_t bytecodeVersion);
98 
99   /// Get the set desired bytecode version to emit.
100   int64_t getDesiredBytecodeVersion() const;
101 
102   /// A map containing the dialect versions to emit.
103   llvm::StringMap<std::unique_ptr<DialectVersion>> &
104   getDialectVersionMap() const;
105 
106   /// Set a given dialect version to emit on the map.
107   template <class T>
108   void setDialectVersion(std::unique_ptr<DialectVersion> dialectVersion) const {
109     return setDialectVersion(T::getDialectNamespace(),
110                              std::move(dialectVersion));
111   }
112   void setDialectVersion(StringRef dialectName,
113                          std::unique_ptr<DialectVersion> dialectVersion) const;
114 
115   //===--------------------------------------------------------------------===//
116   // Types and Attributes encoding
117   //===--------------------------------------------------------------------===//
118 
119   /// Retrieve the callbacks.
120   ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
121   getAttributeWriterCallbacks() const;
122   ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
123   getTypeWriterCallbacks() const;
124 
125   /// Attach a custom bytecode printer callback to the configuration for the
126   /// emission of custom type/attributes encodings.
127   void attachAttributeCallback(
128       std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback);
129   void
130   attachTypeCallback(std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback);
131 
132   /// Attach a custom bytecode printer callback to the configuration for the
133   /// emission of custom type/attributes encodings.
134   template <typename CallableT>
135   std::enable_if_t<std::is_convertible_v<
136       CallableT,
137       std::function<LogicalResult(Attribute, std::optional<StringRef> &,
138                                   DialectBytecodeWriter &)>>>
139   attachAttributeCallback(CallableT &&emitFn) {
140     attachAttributeCallback(AttrTypeBytecodeWriter<Attribute>::fromCallable(
141         std::forward<CallableT>(emitFn)));
142   }
143   template <typename CallableT>
144   std::enable_if_t<std::is_convertible_v<
145       CallableT, std::function<LogicalResult(Type, std::optional<StringRef> &,
146                                              DialectBytecodeWriter &)>>>
147   attachTypeCallback(CallableT &&emitFn) {
148     attachTypeCallback(AttrTypeBytecodeWriter<Type>::fromCallable(
149         std::forward<CallableT>(emitFn)));
150   }
151 
152   //===--------------------------------------------------------------------===//
153   // Resources
154   //===--------------------------------------------------------------------===//
155 
156   /// Set a boolean flag to skip emission of resources into the bytecode file.
157   void setElideResourceDataFlag(bool shouldElideResourceData = true);
158 
159   /// Attach the given resource printer to the writer configuration.
160   void attachResourcePrinter(std::unique_ptr<AsmResourcePrinter> printer);
161 
162   /// Attach an resource printer, in the form of a callable, to the
163   /// configuration.
164   template <typename CallableT>
165   std::enable_if_t<std::is_convertible<
166       CallableT, function_ref<void(Operation *, AsmResourceBuilder &)>>::value>
167   attachResourcePrinter(StringRef name, CallableT &&printFn) {
168     attachResourcePrinter(AsmResourcePrinter::fromCallable(
169         name, std::forward<CallableT>(printFn)));
170   }
171 
172   /// Attach resource printers to the AsmState for the fallback resources
173   /// in the given map.
174   void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) {
175     for (auto &printer : map.getPrinters())
176       attachResourcePrinter(std::move(printer));
177   }
178 
179 private:
180   /// A pointer to allocated storage for the impl state.
181   std::unique_ptr<Impl> impl;
182 };
183 
184 //===----------------------------------------------------------------------===//
185 // Entry Points
186 //===----------------------------------------------------------------------===//
187 
188 /// Write the bytecode for the given operation to the provided output stream.
189 /// For streams where it matters, the given stream should be in "binary" mode.
190 /// It only ever fails if setDesiredByteCodeVersion can't be honored.
191 LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os,
192                                   const BytecodeWriterConfig &config = {});
193 
194 } // namespace mlir
195 
196 #endif // MLIR_BYTECODE_BYTECODEWRITER_H
197