1 //===- BytecodeWriter.h - MLIR Bytecode Writer ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header defines interfaces to write MLIR bytecode files/streams. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_BYTECODE_BYTECODEWRITER_H 14 #define MLIR_BYTECODE_BYTECODEWRITER_H 15 16 #include "mlir/IR/AsmState.h" 17 #include "llvm/Config/llvm-config.h" // for LLVM_VERSION_STRING 18 19 namespace mlir { 20 class DialectBytecodeWriter; 21 class DialectVersion; 22 class Operation; 23 24 /// A class to interact with the attributes and types printer when emitting MLIR 25 /// bytecode. 26 template <class T> 27 class AttrTypeBytecodeWriter { 28 public: 29 AttrTypeBytecodeWriter() = default; 30 virtual ~AttrTypeBytecodeWriter() = default; 31 32 /// Callback writer API used in IRNumbering, where groups are created and 33 /// type/attribute components are numbered. At this stage, writer is expected 34 /// to be a `NumberingDialectWriter`. 35 virtual LogicalResult write(T entry, std::optional<StringRef> &name, 36 DialectBytecodeWriter &writer) = 0; 37 38 /// Callback writer API used in BytecodeWriter, where groups are created and 39 /// type/attribute components are numbered. Here, DialectBytecodeWriter is 40 /// expected to be an actual writer. The optional stringref specified by 41 /// the user is ignored, since the group was already specified when numbering 42 /// the IR. 43 LogicalResult write(T entry, DialectBytecodeWriter &writer) { 44 std::optional<StringRef> dummy; 45 return write(entry, dummy, writer); 46 } 47 48 /// Return an Attribute/Type printer implemented via the given callable, whose 49 /// form should match that of the `write` function above. 50 template <typename CallableT, 51 std::enable_if_t<std::is_convertible_v< 52 CallableT, std::function<LogicalResult( 53 T, std::optional<StringRef> &, 54 DialectBytecodeWriter &)>>, 55 bool> = true> 56 static std::unique_ptr<AttrTypeBytecodeWriter<T>> 57 fromCallable(CallableT &&writeFn) { 58 struct Processor : public AttrTypeBytecodeWriter<T> { 59 Processor(CallableT &&writeFn) 60 : AttrTypeBytecodeWriter(), writeFn(std::move(writeFn)) {} 61 LogicalResult write(T entry, std::optional<StringRef> &name, 62 DialectBytecodeWriter &writer) override { 63 return writeFn(entry, name, writer); 64 } 65 66 std::decay_t<CallableT> writeFn; 67 }; 68 return std::make_unique<Processor>(std::forward<CallableT>(writeFn)); 69 } 70 }; 71 72 /// This class contains the configuration used for the bytecode writer. It 73 /// controls various aspects of bytecode generation, and contains all of the 74 /// various bytecode writer hooks. 75 class BytecodeWriterConfig { 76 public: 77 /// `producer` is an optional string that can be used to identify the producer 78 /// of the bytecode when reading. It has no functional effect on the bytecode 79 /// serialization. 80 BytecodeWriterConfig(StringRef producer = "MLIR" LLVM_VERSION_STRING); 81 /// `map` is a fallback resource map, which when provided will attach resource 82 /// printers for the fallback resources within the map. 83 BytecodeWriterConfig(FallbackAsmResourceMap &map, 84 StringRef producer = "MLIR" LLVM_VERSION_STRING); 85 ~BytecodeWriterConfig(); 86 87 /// An internal implementation class that contains the state of the 88 /// configuration. 89 struct Impl; 90 91 /// Return an instance of the internal implementation. 92 const Impl &getImpl() const { return *impl; } 93 94 /// Set the desired bytecode version to emit. This method does not validate 95 /// the desired version. The bytecode writer entry point will return failure 96 /// if it cannot emit the desired version. 97 void setDesiredBytecodeVersion(int64_t bytecodeVersion); 98 99 /// Get the set desired bytecode version to emit. 100 int64_t getDesiredBytecodeVersion() const; 101 102 /// A map containing the dialect versions to emit. 103 llvm::StringMap<std::unique_ptr<DialectVersion>> & 104 getDialectVersionMap() const; 105 106 /// Set a given dialect version to emit on the map. 107 template <class T> 108 void setDialectVersion(std::unique_ptr<DialectVersion> dialectVersion) const { 109 return setDialectVersion(T::getDialectNamespace(), 110 std::move(dialectVersion)); 111 } 112 void setDialectVersion(StringRef dialectName, 113 std::unique_ptr<DialectVersion> dialectVersion) const; 114 115 //===--------------------------------------------------------------------===// 116 // Types and Attributes encoding 117 //===--------------------------------------------------------------------===// 118 119 /// Retrieve the callbacks. 120 ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> 121 getAttributeWriterCallbacks() const; 122 ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> 123 getTypeWriterCallbacks() const; 124 125 /// Attach a custom bytecode printer callback to the configuration for the 126 /// emission of custom type/attributes encodings. 127 void attachAttributeCallback( 128 std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback); 129 void 130 attachTypeCallback(std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback); 131 132 /// Attach a custom bytecode printer callback to the configuration for the 133 /// emission of custom type/attributes encodings. 134 template <typename CallableT> 135 std::enable_if_t<std::is_convertible_v< 136 CallableT, 137 std::function<LogicalResult(Attribute, std::optional<StringRef> &, 138 DialectBytecodeWriter &)>>> 139 attachAttributeCallback(CallableT &&emitFn) { 140 attachAttributeCallback(AttrTypeBytecodeWriter<Attribute>::fromCallable( 141 std::forward<CallableT>(emitFn))); 142 } 143 template <typename CallableT> 144 std::enable_if_t<std::is_convertible_v< 145 CallableT, std::function<LogicalResult(Type, std::optional<StringRef> &, 146 DialectBytecodeWriter &)>>> 147 attachTypeCallback(CallableT &&emitFn) { 148 attachTypeCallback(AttrTypeBytecodeWriter<Type>::fromCallable( 149 std::forward<CallableT>(emitFn))); 150 } 151 152 //===--------------------------------------------------------------------===// 153 // Resources 154 //===--------------------------------------------------------------------===// 155 156 /// Set a boolean flag to skip emission of resources into the bytecode file. 157 void setElideResourceDataFlag(bool shouldElideResourceData = true); 158 159 /// Attach the given resource printer to the writer configuration. 160 void attachResourcePrinter(std::unique_ptr<AsmResourcePrinter> printer); 161 162 /// Attach an resource printer, in the form of a callable, to the 163 /// configuration. 164 template <typename CallableT> 165 std::enable_if_t<std::is_convertible< 166 CallableT, function_ref<void(Operation *, AsmResourceBuilder &)>>::value> 167 attachResourcePrinter(StringRef name, CallableT &&printFn) { 168 attachResourcePrinter(AsmResourcePrinter::fromCallable( 169 name, std::forward<CallableT>(printFn))); 170 } 171 172 /// Attach resource printers to the AsmState for the fallback resources 173 /// in the given map. 174 void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) { 175 for (auto &printer : map.getPrinters()) 176 attachResourcePrinter(std::move(printer)); 177 } 178 179 private: 180 /// A pointer to allocated storage for the impl state. 181 std::unique_ptr<Impl> impl; 182 }; 183 184 //===----------------------------------------------------------------------===// 185 // Entry Points 186 //===----------------------------------------------------------------------===// 187 188 /// Write the bytecode for the given operation to the provided output stream. 189 /// For streams where it matters, the given stream should be in "binary" mode. 190 /// It only ever fails if setDesiredByteCodeVersion can't be honored. 191 LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os, 192 const BytecodeWriterConfig &config = {}); 193 194 } // namespace mlir 195 196 #endif // MLIR_BYTECODE_BYTECODEWRITER_H 197