1 //===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains various utilities that number IR structures in preparation 10 // for bytecode emission. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H 15 #define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H 16 17 #include "mlir/IR/OpImplementation.h" 18 #include "llvm/ADT/MapVector.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/StringMap.h" 21 #include <cstdint> 22 23 namespace mlir { 24 class BytecodeDialectInterface; 25 class BytecodeWriterConfig; 26 27 namespace bytecode { 28 namespace detail { 29 struct DialectNumbering; 30 31 //===----------------------------------------------------------------------===// 32 // Attribute and Type Numbering 33 //===----------------------------------------------------------------------===// 34 35 /// This class represents a numbering entry for an Attribute or Type. 36 struct AttrTypeNumbering { 37 AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {} 38 39 /// The concrete value. 40 PointerUnion<Attribute, Type> value; 41 42 /// The number assigned to this value. 43 unsigned number = 0; 44 45 /// The number of references to this value. 46 unsigned refCount = 1; 47 48 /// The dialect of this value. 49 DialectNumbering *dialect = nullptr; 50 }; 51 struct AttributeNumbering : public AttrTypeNumbering { 52 AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {} 53 Attribute getValue() const { return cast<Attribute>(value); } 54 }; 55 struct TypeNumbering : public AttrTypeNumbering { 56 TypeNumbering(Type value) : AttrTypeNumbering(value) {} 57 Type getValue() const { return cast<Type>(value); } 58 }; 59 60 //===----------------------------------------------------------------------===// 61 // OpName Numbering 62 //===----------------------------------------------------------------------===// 63 64 /// This class represents the numbering entry of an operation name. 65 struct OpNameNumbering { 66 OpNameNumbering(DialectNumbering *dialect, OperationName name) 67 : dialect(dialect), name(name) {} 68 69 /// The dialect of this value. 70 DialectNumbering *dialect; 71 72 /// The concrete name. 73 OperationName name; 74 75 /// The number assigned to this name. 76 unsigned number = 0; 77 78 /// The number of references to this name. 79 unsigned refCount = 1; 80 }; 81 82 //===----------------------------------------------------------------------===// 83 // Dialect Resource Numbering 84 //===----------------------------------------------------------------------===// 85 86 /// This class represents a numbering entry for a dialect resource. 87 struct DialectResourceNumbering { 88 DialectResourceNumbering(std::string key) : key(std::move(key)) {} 89 90 /// The key used to reference this resource. 91 std::string key; 92 93 /// The number assigned to this resource. 94 unsigned number = 0; 95 96 /// A flag indicating if this resource is only a declaration, not a full 97 /// definition. 98 bool isDeclaration = true; 99 }; 100 101 //===----------------------------------------------------------------------===// 102 // Dialect Numbering 103 //===----------------------------------------------------------------------===// 104 105 /// This class represents a numbering entry for an Dialect. 106 struct DialectNumbering { 107 DialectNumbering(StringRef name, unsigned number) 108 : name(name), number(number) {} 109 110 /// The namespace of the dialect. 111 StringRef name; 112 113 /// The number assigned to the dialect. 114 unsigned number; 115 116 /// The bytecode dialect interface of the dialect if defined. 117 const BytecodeDialectInterface *interface = nullptr; 118 119 /// The asm dialect interface of the dialect if defined. 120 const OpAsmDialectInterface *asmInterface = nullptr; 121 122 /// The referenced resources of this dialect. 123 SetVector<AsmDialectResourceHandle> resources; 124 125 /// A mapping from resource key to the corresponding resource numbering entry. 126 llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap; 127 }; 128 129 //===----------------------------------------------------------------------===// 130 // Operation Numbering 131 //===----------------------------------------------------------------------===// 132 133 /// This class represents the numbering entry of an operation. 134 struct OperationNumbering { 135 OperationNumbering(unsigned number) : number(number) {} 136 137 /// The number assigned to this operation. 138 unsigned number; 139 140 /// A flag indicating if this operation's regions are isolated. If unset, the 141 /// operation isn't yet known to be isolated. 142 std::optional<bool> isIsolatedFromAbove; 143 }; 144 145 //===----------------------------------------------------------------------===// 146 // IRNumberingState 147 //===----------------------------------------------------------------------===// 148 149 /// This class manages numbering IR entities in preparation of bytecode 150 /// emission. 151 class IRNumberingState { 152 public: 153 IRNumberingState(Operation *op, const BytecodeWriterConfig &config); 154 155 /// Return the numbered dialects. 156 auto getDialects() { 157 return llvm::make_pointee_range(llvm::make_second_range(dialects)); 158 } 159 auto getAttributes() { return llvm::make_pointee_range(orderedAttrs); } 160 auto getOpNames() { return llvm::make_pointee_range(orderedOpNames); } 161 auto getTypes() { return llvm::make_pointee_range(orderedTypes); } 162 163 /// Return the number for the given IR unit. 164 unsigned getNumber(Attribute attr) { 165 assert(attrs.count(attr) && "attribute not numbered"); 166 return attrs[attr]->number; 167 } 168 unsigned getNumber(Block *block) { 169 assert(blockIDs.count(block) && "block not numbered"); 170 return blockIDs[block]; 171 } 172 unsigned getNumber(Operation *op) { 173 assert(operations.count(op) && "operation not numbered"); 174 return operations[op]->number; 175 } 176 unsigned getNumber(OperationName opName) { 177 assert(opNames.count(opName) && "opName not numbered"); 178 return opNames[opName]->number; 179 } 180 unsigned getNumber(Type type) { 181 assert(types.count(type) && "type not numbered"); 182 return types[type]->number; 183 } 184 unsigned getNumber(Value value) { 185 assert(valueIDs.count(value) && "value not numbered"); 186 return valueIDs[value]; 187 } 188 unsigned getNumber(const AsmDialectResourceHandle &resource) { 189 assert(dialectResources.count(resource) && "resource not numbered"); 190 return dialectResources[resource]->number; 191 } 192 193 /// Return the block and value counts of the given region. 194 std::pair<unsigned, unsigned> getBlockValueCount(Region *region) { 195 assert(regionBlockValueCounts.count(region) && "value not numbered"); 196 return regionBlockValueCounts[region]; 197 } 198 199 /// Return the number of operations in the given block. 200 unsigned getOperationCount(Block *block) { 201 assert(blockOperationCounts.count(block) && "block not numbered"); 202 return blockOperationCounts[block]; 203 } 204 205 /// Return if the given operation is isolated from above. 206 bool isIsolatedFromAbove(Operation *op) { 207 assert(operations.count(op) && "operation not numbered"); 208 return operations[op]->isIsolatedFromAbove.value_or(false); 209 } 210 211 /// Get the set desired bytecode version to emit. 212 int64_t getDesiredBytecodeVersion() const; 213 214 private: 215 /// This class is used to provide a fake dialect writer for numbering nested 216 /// attributes and types. 217 struct NumberingDialectWriter; 218 219 /// Compute the global numbering state for the given root operation. 220 void computeGlobalNumberingState(Operation *rootOp); 221 222 /// Number the given IR unit for bytecode emission. 223 void number(Attribute attr); 224 void number(Block &block); 225 DialectNumbering &numberDialect(Dialect *dialect); 226 DialectNumbering &numberDialect(StringRef dialect); 227 void number(Operation &op); 228 void number(OperationName opName); 229 void number(Region ®ion); 230 void number(Type type); 231 232 /// Number the given dialect resources. 233 void number(Dialect *dialect, ArrayRef<AsmDialectResourceHandle> resources); 234 235 /// Finalize the numberings of any dialect resources. 236 void finalizeDialectResourceNumberings(Operation *rootOp); 237 238 /// Mapping from IR to the respective numbering entries. 239 DenseMap<Attribute, AttributeNumbering *> attrs; 240 DenseMap<Operation *, OperationNumbering *> operations; 241 DenseMap<OperationName, OpNameNumbering *> opNames; 242 DenseMap<Type, TypeNumbering *> types; 243 DenseMap<Dialect *, DialectNumbering *> registeredDialects; 244 llvm::MapVector<StringRef, DialectNumbering *> dialects; 245 std::vector<AttributeNumbering *> orderedAttrs; 246 std::vector<OpNameNumbering *> orderedOpNames; 247 std::vector<TypeNumbering *> orderedTypes; 248 249 /// A mapping from dialect resource handle to the numbering for the referenced 250 /// resource. 251 llvm::DenseMap<AsmDialectResourceHandle, DialectResourceNumbering *> 252 dialectResources; 253 254 /// Allocators used for the various numbering entries. 255 llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator; 256 llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator; 257 llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator; 258 llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator; 259 llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator; 260 llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator; 261 262 /// The value ID for each Block and Value. 263 DenseMap<Block *, unsigned> blockIDs; 264 DenseMap<Value, unsigned> valueIDs; 265 266 /// The number of operations in each block. 267 DenseMap<Block *, unsigned> blockOperationCounts; 268 269 /// A map from region to the number of blocks and values within that region. 270 DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts; 271 272 /// The next value ID to assign when numbering. 273 unsigned nextValueID = 0; 274 275 // Configuration: useful to query the required version to emit. 276 const BytecodeWriterConfig &config; 277 }; 278 } // namespace detail 279 } // namespace bytecode 280 } // namespace mlir 281 282 #endif 283