xref: /llvm-project/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h (revision f1ac7725e4fd5afa21fb244f9bcc33de654ed80c)
1 //===- BytecodeReaderConfig.h - MLIR Bytecode Reader Config -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This header defines the configuration used when reading MLIR bytecode
// files/streams.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_BYTECODE_BYTECODEREADERCONFIG_H
14 #define MLIR_BYTECODE_BYTECODEREADERCONFIG_H
15 
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
20 
21 namespace mlir {
22 class Attribute;
23 class DialectBytecodeReader;
24 class Type;
25 
26 /// A class to interact with the attributes and types parser when parsing MLIR
27 /// bytecode.
28 template <class T>
29 class AttrTypeBytecodeReader {
30 public:
31   AttrTypeBytecodeReader() = default;
32   virtual ~AttrTypeBytecodeReader() = default;
33 
34   virtual LogicalResult read(DialectBytecodeReader &reader,
35                              StringRef dialectName, T &entry) = 0;
36 
37   /// Return an Attribute/Type printer implemented via the given callable, whose
38   /// form should match that of the `parse` function above.
39   template <typename CallableT,
40             std::enable_if_t<
41                 std::is_convertible_v<
42                     CallableT, std::function<LogicalResult(
43                                    DialectBytecodeReader &, StringRef, T &)>>,
44                 bool> = true>
45   static std::unique_ptr<AttrTypeBytecodeReader<T>>
46   fromCallable(CallableT &&readFn) {
47     struct Processor : public AttrTypeBytecodeReader<T> {
48       Processor(CallableT &&readFn)
49           : AttrTypeBytecodeReader(), readFn(std::move(readFn)) {}
50       LogicalResult read(DialectBytecodeReader &reader, StringRef dialectName,
51                          T &entry) override {
52         return readFn(reader, dialectName, entry);
53       }
54 
55       std::decay_t<CallableT> readFn;
56     };
57     return std::make_unique<Processor>(std::forward<CallableT>(readFn));
58   }
59 };
60 
61 //===----------------------------------------------------------------------===//
62 // BytecodeReaderConfig
63 //===----------------------------------------------------------------------===//
64 
65 /// A class containing bytecode-specific configurations of the `ParserConfig`.
66 class BytecodeReaderConfig {
67 public:
68   BytecodeReaderConfig() = default;
69 
70   /// Returns the callbacks available to the parser.
71   ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
72   getAttributeCallbacks() const {
73     return attributeBytecodeParsers;
74   }
75   ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
76   getTypeCallbacks() const {
77     return typeBytecodeParsers;
78   }
79 
80   /// Attach a custom bytecode parser callback to the configuration for parsing
81   /// of custom type/attributes encodings.
82   void attachAttributeCallback(
83       std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
84     attributeBytecodeParsers.emplace_back(std::move(parser));
85   }
86   void
87   attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) {
88     typeBytecodeParsers.emplace_back(std::move(parser));
89   }
90 
91   /// Attach a custom bytecode parser callback to the configuration for parsing
92   /// of custom type/attributes encodings.
93   template <typename CallableT>
94   std::enable_if_t<std::is_convertible_v<
95       CallableT, std::function<LogicalResult(DialectBytecodeReader &, StringRef,
96                                              Attribute &)>>>
97   attachAttributeCallback(CallableT &&parserFn) {
98     attachAttributeCallback(AttrTypeBytecodeReader<Attribute>::fromCallable(
99         std::forward<CallableT>(parserFn)));
100   }
101   template <typename CallableT>
102   std::enable_if_t<std::is_convertible_v<
103       CallableT,
104       std::function<LogicalResult(DialectBytecodeReader &, StringRef, Type &)>>>
105   attachTypeCallback(CallableT &&parserFn) {
106     attachTypeCallback(AttrTypeBytecodeReader<Type>::fromCallable(
107         std::forward<CallableT>(parserFn)));
108   }
109 
110 private:
111   llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
112       attributeBytecodeParsers;
113   llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
114       typeBytecodeParsers;
115 };
116 
117 } // namespace mlir
118 
119 #endif // MLIR_BYTECODE_BYTECODEREADERCONFIG_H
120