xref: /llvm-project/mlir/lib/Bytecode/Writer/IRNumbering.h (revision 4f4e2abb1a5ff1225d32410fd02b732d077aa056)
1 //===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains various utilities that number IR structures in preparation
10 // for bytecode emission.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
15 #define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
16 
17 #include "mlir/IR/OpImplementation.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringMap.h"
21 #include <cstdint>
22 
23 namespace mlir {
24 class BytecodeDialectInterface;
25 class BytecodeWriterConfig;
26 
27 namespace bytecode {
28 namespace detail {
29 struct DialectNumbering;
30 
31 //===----------------------------------------------------------------------===//
32 // Attribute and Type Numbering
33 //===----------------------------------------------------------------------===//
34 
35 /// This class represents a numbering entry for an Attribute or Type.
36 struct AttrTypeNumbering {
37   AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {}
38 
39   /// The concrete value.
40   PointerUnion<Attribute, Type> value;
41 
42   /// The number assigned to this value.
43   unsigned number = 0;
44 
45   /// The number of references to this value.
46   unsigned refCount = 1;
47 
48   /// The dialect of this value.
49   DialectNumbering *dialect = nullptr;
50 };
51 struct AttributeNumbering : public AttrTypeNumbering {
52   AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {}
53   Attribute getValue() const { return cast<Attribute>(value); }
54 };
55 struct TypeNumbering : public AttrTypeNumbering {
56   TypeNumbering(Type value) : AttrTypeNumbering(value) {}
57   Type getValue() const { return cast<Type>(value); }
58 };
59 
60 //===----------------------------------------------------------------------===//
61 // OpName Numbering
62 //===----------------------------------------------------------------------===//
63 
64 /// This class represents the numbering entry of an operation name.
65 struct OpNameNumbering {
66   OpNameNumbering(DialectNumbering *dialect, OperationName name)
67       : dialect(dialect), name(name) {}
68 
69   /// The dialect of this value.
70   DialectNumbering *dialect;
71 
72   /// The concrete name.
73   OperationName name;
74 
75   /// The number assigned to this name.
76   unsigned number = 0;
77 
78   /// The number of references to this name.
79   unsigned refCount = 1;
80 };
81 
82 //===----------------------------------------------------------------------===//
83 // Dialect Resource Numbering
84 //===----------------------------------------------------------------------===//
85 
/// Numbering record for a resource referenced through a dialect.
struct DialectResourceNumbering {
  DialectResourceNumbering(std::string resourceKey)
      : key(std::move(resourceKey)) {}

  /// Key under which the resource is referenced.
  std::string key;

  /// Number assigned to the resource; defaults to zero.
  unsigned number{0};

  /// Whether the resource is only a declaration rather than a full
  /// definition; defaults to true.
  bool isDeclaration{true};
};
100 
101 //===----------------------------------------------------------------------===//
102 // Dialect Numbering
103 //===----------------------------------------------------------------------===//
104 
105 /// This class represents a numbering entry for an Dialect.
106 struct DialectNumbering {
107   DialectNumbering(StringRef name, unsigned number)
108       : name(name), number(number) {}
109 
110   /// The namespace of the dialect.
111   StringRef name;
112 
113   /// The number assigned to the dialect.
114   unsigned number;
115 
116   /// The bytecode dialect interface of the dialect if defined.
117   const BytecodeDialectInterface *interface = nullptr;
118 
119   /// The asm dialect interface of the dialect if defined.
120   const OpAsmDialectInterface *asmInterface = nullptr;
121 
122   /// The referenced resources of this dialect.
123   SetVector<AsmDialectResourceHandle> resources;
124 
125   /// A mapping from resource key to the corresponding resource numbering entry.
126   llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap;
127 };
128 
129 //===----------------------------------------------------------------------===//
130 // Operation Numbering
131 //===----------------------------------------------------------------------===//
132 
/// Numbering record for a single operation.
struct OperationNumbering {
  OperationNumbering(unsigned opNumber) : number(opNumber) {}

  /// Number assigned to the operation.
  unsigned number;

  /// Whether the operation's regions are isolated from above. An empty
  /// optional means this is not known yet.
  std::optional<bool> isIsolatedFromAbove;
};
144 
145 //===----------------------------------------------------------------------===//
146 // IRNumberingState
147 //===----------------------------------------------------------------------===//
148 
149 /// This class manages numbering IR entities in preparation of bytecode
150 /// emission.
151 class IRNumberingState {
152 public:
153   IRNumberingState(Operation *op, const BytecodeWriterConfig &config);
154 
155   /// Return the numbered dialects.
156   auto getDialects() {
157     return llvm::make_pointee_range(llvm::make_second_range(dialects));
158   }
159   auto getAttributes() { return llvm::make_pointee_range(orderedAttrs); }
160   auto getOpNames() { return llvm::make_pointee_range(orderedOpNames); }
161   auto getTypes() { return llvm::make_pointee_range(orderedTypes); }
162 
163   /// Return the number for the given IR unit.
164   unsigned getNumber(Attribute attr) {
165     assert(attrs.count(attr) && "attribute not numbered");
166     return attrs[attr]->number;
167   }
168   unsigned getNumber(Block *block) {
169     assert(blockIDs.count(block) && "block not numbered");
170     return blockIDs[block];
171   }
172   unsigned getNumber(Operation *op) {
173     assert(operations.count(op) && "operation not numbered");
174     return operations[op]->number;
175   }
176   unsigned getNumber(OperationName opName) {
177     assert(opNames.count(opName) && "opName not numbered");
178     return opNames[opName]->number;
179   }
180   unsigned getNumber(Type type) {
181     assert(types.count(type) && "type not numbered");
182     return types[type]->number;
183   }
184   unsigned getNumber(Value value) {
185     assert(valueIDs.count(value) && "value not numbered");
186     return valueIDs[value];
187   }
188   unsigned getNumber(const AsmDialectResourceHandle &resource) {
189     assert(dialectResources.count(resource) && "resource not numbered");
190     return dialectResources[resource]->number;
191   }
192 
193   /// Return the block and value counts of the given region.
194   std::pair<unsigned, unsigned> getBlockValueCount(Region *region) {
195     assert(regionBlockValueCounts.count(region) && "value not numbered");
196     return regionBlockValueCounts[region];
197   }
198 
199   /// Return the number of operations in the given block.
200   unsigned getOperationCount(Block *block) {
201     assert(blockOperationCounts.count(block) && "block not numbered");
202     return blockOperationCounts[block];
203   }
204 
205   /// Return if the given operation is isolated from above.
206   bool isIsolatedFromAbove(Operation *op) {
207     assert(operations.count(op) && "operation not numbered");
208     return operations[op]->isIsolatedFromAbove.value_or(false);
209   }
210 
211   /// Get the set desired bytecode version to emit.
212   int64_t getDesiredBytecodeVersion() const;
213 
214 private:
215   /// This class is used to provide a fake dialect writer for numbering nested
216   /// attributes and types.
217   struct NumberingDialectWriter;
218 
219   /// Compute the global numbering state for the given root operation.
220   void computeGlobalNumberingState(Operation *rootOp);
221 
222   /// Number the given IR unit for bytecode emission.
223   void number(Attribute attr);
224   void number(Block &block);
225   DialectNumbering &numberDialect(Dialect *dialect);
226   DialectNumbering &numberDialect(StringRef dialect);
227   void number(Operation &op);
228   void number(OperationName opName);
229   void number(Region &region);
230   void number(Type type);
231 
232   /// Number the given dialect resources.
233   void number(Dialect *dialect, ArrayRef<AsmDialectResourceHandle> resources);
234 
235   /// Finalize the numberings of any dialect resources.
236   void finalizeDialectResourceNumberings(Operation *rootOp);
237 
238   /// Mapping from IR to the respective numbering entries.
239   DenseMap<Attribute, AttributeNumbering *> attrs;
240   DenseMap<Operation *, OperationNumbering *> operations;
241   DenseMap<OperationName, OpNameNumbering *> opNames;
242   DenseMap<Type, TypeNumbering *> types;
243   DenseMap<Dialect *, DialectNumbering *> registeredDialects;
244   llvm::MapVector<StringRef, DialectNumbering *> dialects;
245   std::vector<AttributeNumbering *> orderedAttrs;
246   std::vector<OpNameNumbering *> orderedOpNames;
247   std::vector<TypeNumbering *> orderedTypes;
248 
249   /// A mapping from dialect resource handle to the numbering for the referenced
250   /// resource.
251   llvm::DenseMap<AsmDialectResourceHandle, DialectResourceNumbering *>
252       dialectResources;
253 
254   /// Allocators used for the various numbering entries.
255   llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
256   llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
257   llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator;
258   llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
259   llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
260   llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
261 
262   /// The value ID for each Block and Value.
263   DenseMap<Block *, unsigned> blockIDs;
264   DenseMap<Value, unsigned> valueIDs;
265 
266   /// The number of operations in each block.
267   DenseMap<Block *, unsigned> blockOperationCounts;
268 
269   /// A map from region to the number of blocks and values within that region.
270   DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts;
271 
272   /// The next value ID to assign when numbering.
273   unsigned nextValueID = 0;
274 
275   // Configuration: useful to query the required version to emit.
276   const BytecodeWriterConfig &config;
277 };
278 } // namespace detail
279 } // namespace bytecode
280 } // namespace mlir
281 
282 #endif
283