1 //===- BytecodeReaderConfig.h - MLIR Bytecode Reader Config -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header config for reading MLIR bytecode files/streams. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_BYTECODE_BYTECODEREADERCONFIG_H 14 #define MLIR_BYTECODE_BYTECODEREADERCONFIG_H 15 16 #include "mlir/Support/LLVM.h" 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/StringRef.h" 20 21 namespace mlir { 22 class Attribute; 23 class DialectBytecodeReader; 24 class Type; 25 26 /// A class to interact with the attributes and types parser when parsing MLIR 27 /// bytecode. 28 template <class T> 29 class AttrTypeBytecodeReader { 30 public: 31 AttrTypeBytecodeReader() = default; 32 virtual ~AttrTypeBytecodeReader() = default; 33 34 virtual LogicalResult read(DialectBytecodeReader &reader, 35 StringRef dialectName, T &entry) = 0; 36 37 /// Return an Attribute/Type printer implemented via the given callable, whose 38 /// form should match that of the `parse` function above. 39 template <typename CallableT, 40 std::enable_if_t< 41 std::is_convertible_v< 42 CallableT, std::function<LogicalResult( 43 DialectBytecodeReader &, StringRef, T &)>>, 44 bool> = true> 45 static std::unique_ptr<AttrTypeBytecodeReader<T>> 46 fromCallable(CallableT &&readFn) { 47 struct Processor : public AttrTypeBytecodeReader<T> { 48 Processor(CallableT &&readFn) 49 : AttrTypeBytecodeReader(), readFn(std::move(readFn)) {} 50 LogicalResult read(DialectBytecodeReader &reader, StringRef dialectName, 51 T &entry) override { 52 return readFn(reader, dialectName, entry); 53 } 54 55 std::decay_t<CallableT> readFn; 56 }; 57 return std::make_unique<Processor>(std::forward<CallableT>(readFn)); 58 } 59 }; 60 61 //===----------------------------------------------------------------------===// 62 // BytecodeReaderConfig 63 //===----------------------------------------------------------------------===// 64 65 /// A class containing bytecode-specific configurations of the `ParserConfig`. 66 class BytecodeReaderConfig { 67 public: 68 BytecodeReaderConfig() = default; 69 70 /// Returns the callbacks available to the parser. 71 ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>> 72 getAttributeCallbacks() const { 73 return attributeBytecodeParsers; 74 } 75 ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>> 76 getTypeCallbacks() const { 77 return typeBytecodeParsers; 78 } 79 80 /// Attach a custom bytecode parser callback to the configuration for parsing 81 /// of custom type/attributes encodings. 82 void attachAttributeCallback( 83 std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) { 84 attributeBytecodeParsers.emplace_back(std::move(parser)); 85 } 86 void 87 attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) { 88 typeBytecodeParsers.emplace_back(std::move(parser)); 89 } 90 91 /// Attach a custom bytecode parser callback to the configuration for parsing 92 /// of custom type/attributes encodings. 93 template <typename CallableT> 94 std::enable_if_t<std::is_convertible_v< 95 CallableT, std::function<LogicalResult(DialectBytecodeReader &, StringRef, 96 Attribute &)>>> 97 attachAttributeCallback(CallableT &&parserFn) { 98 attachAttributeCallback(AttrTypeBytecodeReader<Attribute>::fromCallable( 99 std::forward<CallableT>(parserFn))); 100 } 101 template <typename CallableT> 102 std::enable_if_t<std::is_convertible_v< 103 CallableT, 104 std::function<LogicalResult(DialectBytecodeReader &, StringRef, Type &)>>> 105 attachTypeCallback(CallableT &&parserFn) { 106 attachTypeCallback(AttrTypeBytecodeReader<Type>::fromCallable( 107 std::forward<CallableT>(parserFn))); 108 } 109 110 private: 111 llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>> 112 attributeBytecodeParsers; 113 llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>> 114 typeBytecodeParsers; 115 }; 116 117 } // namespace mlir 118 119 #endif // MLIR_BYTECODE_BYTECODEREADERCONFIG_H 120