1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_SYMBOLTABLE_H 10 #define MLIR_IR_SYMBOLTABLE_H 11 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/OpDefinition.h" 14 #include "llvm/ADT/SetVector.h" 15 #include "llvm/ADT/StringMap.h" 16 #include "llvm/Support/RWMutex.h" 17 18 namespace mlir { 19 20 /// This class allows for representing and managing the symbol table used by 21 /// operations with the 'SymbolTable' trait. Inserting into and erasing from 22 /// this SymbolTable will also insert and erase from the Operation given to it 23 /// at construction. 24 class SymbolTable { 25 public: 26 /// Build a symbol table with the symbols within the given operation. 27 SymbolTable(Operation *symbolTableOp); 28 29 /// Look up a symbol with the specified name, returning null if no such 30 /// name exists. Names never include the @ on them. 31 Operation *lookup(StringRef name) const; 32 template <typename T> lookup(StringRef name)33 T lookup(StringRef name) const { 34 return dyn_cast_or_null<T>(lookup(name)); 35 } 36 37 /// Look up a symbol with the specified name, returning null if no such 38 /// name exists. Names never include the @ on them. 39 Operation *lookup(StringAttr name) const; 40 template <typename T> lookup(StringAttr name)41 T lookup(StringAttr name) const { 42 return dyn_cast_or_null<T>(lookup(name)); 43 } 44 45 /// Remove the given symbol from the table, without deleting it. 46 void remove(Operation *op); 47 48 /// Erase the given symbol from the table and delete the operation. 49 void erase(Operation *symbol); 50 51 /// Insert a new symbol into the table, and rename it as necessary to avoid 52 /// collisions. Also insert at the specified location in the body of the 53 /// associated operation if it is not already there. It is asserted that the 54 /// symbol is not inside another operation. Return the name of the symbol 55 /// after insertion as attribute. 56 StringAttr insert(Operation *symbol, Block::iterator insertPt = {}); 57 58 /// Renames the given op or the op refered to by the given name to the given 59 /// new name and updates the symbol table and all usages of the symbol 60 /// accordingly. Fails if the updating of the usages fails. 61 LogicalResult rename(StringAttr from, StringAttr to); 62 LogicalResult rename(Operation *op, StringAttr to); 63 LogicalResult rename(StringAttr from, StringRef to); 64 LogicalResult rename(Operation *op, StringRef to); 65 66 /// Renames the given op or the op refered to by the given name to the a name 67 /// that is unique within this and the provided other symbol tables and 68 /// updates the symbol table and all usages of the symbol accordingly. Returns 69 /// the new name or failure if the renaming fails. 70 FailureOr<StringAttr> renameToUnique(StringAttr from, 71 ArrayRef<SymbolTable *> others); 72 FailureOr<StringAttr> renameToUnique(Operation *op, 73 ArrayRef<SymbolTable *> others); 74 75 /// Return the name of the attribute used for symbol names. getSymbolAttrName()76 static StringRef getSymbolAttrName() { return "sym_name"; } 77 78 /// Returns the associated operation. getOp()79 Operation *getOp() const { return symbolTableOp; } 80 81 /// Return the name of the attribute used for symbol visibility. getVisibilityAttrName()82 static StringRef getVisibilityAttrName() { return "sym_visibility"; } 83 84 //===--------------------------------------------------------------------===// 85 // Symbol Utilities 86 //===--------------------------------------------------------------------===// 87 88 /// An enumeration detailing the different visibility types that a symbol may 89 /// have. 90 enum class Visibility { 91 /// The symbol is public and may be referenced anywhere internal or external 92 /// to the visible references in the IR. 93 Public, 94 95 /// The symbol is private and may only be referenced by SymbolRefAttrs local 96 /// to the operations within the current symbol table. 97 Private, 98 99 /// The symbol is visible to the current IR, which may include operations in 100 /// symbol tables above the one that owns the current symbol. `Nested` 101 /// visibility allows for referencing a symbol outside of its current symbol 102 /// table, while retaining the ability to observe all uses. 103 Nested, 104 }; 105 106 /// Generate a unique symbol name. Iteratively increase uniquingCounter 107 /// and use it as a suffix for symbol names until uniqueChecker does not 108 /// detect any conflict. 109 template <unsigned N, typename UniqueChecker> generateSymbolName(StringRef name,UniqueChecker uniqueChecker,unsigned & uniquingCounter)110 static SmallString<N> generateSymbolName(StringRef name, 111 UniqueChecker uniqueChecker, 112 unsigned &uniquingCounter) { 113 SmallString<N> nameBuffer(name); 114 unsigned originalLength = nameBuffer.size(); 115 do { 116 nameBuffer.resize(originalLength); 117 nameBuffer += '_'; 118 nameBuffer += std::to_string(uniquingCounter++); 119 } while (uniqueChecker(nameBuffer)); 120 121 return nameBuffer; 122 } 123 124 /// Returns the name of the given symbol operation, aborting if no symbol is 125 /// present. 126 static StringAttr getSymbolName(Operation *symbol); 127 128 /// Sets the name of the given symbol operation. 129 static void setSymbolName(Operation *symbol, StringAttr name); setSymbolName(Operation * symbol,StringRef name)130 static void setSymbolName(Operation *symbol, StringRef name) { 131 setSymbolName(symbol, StringAttr::get(symbol->getContext(), name)); 132 } 133 134 /// Returns the visibility of the given symbol operation. 135 static Visibility getSymbolVisibility(Operation *symbol); 136 /// Sets the visibility of the given symbol operation. 137 static void setSymbolVisibility(Operation *symbol, Visibility vis); 138 139 /// Returns the nearest symbol table from a given operation `from`. Returns 140 /// nullptr if no valid parent symbol table could be found. 141 static Operation *getNearestSymbolTable(Operation *from); 142 143 /// Walks all symbol table operations nested within, and including, `op`. For 144 /// each symbol table operation, the provided callback is invoked with the op 145 /// and a boolean signifying if the symbols within that symbol table can be 146 /// treated as if all uses within the IR are visible to the caller. 147 /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols 148 /// within `op` are visible. 149 static void walkSymbolTables(Operation *op, bool allSymUsesVisible, 150 function_ref<void(Operation *, bool)> callback); 151 152 /// Returns the operation registered with the given symbol name with the 153 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation 154 /// with the 'OpTrait::SymbolTable' trait. 155 static Operation *lookupSymbolIn(Operation *op, StringAttr symbol); lookupSymbolIn(Operation * op,StringRef symbol)156 static Operation *lookupSymbolIn(Operation *op, StringRef symbol) { 157 return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol)); 158 } 159 static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); 160 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced 161 /// by a given SymbolRefAttr. Returns failure if any of the nested references 162 /// could not be resolved. 163 static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol, 164 SmallVectorImpl<Operation *> &symbols); 165 166 /// Returns the operation registered with the given symbol name within the 167 /// closest parent operation of, or including, 'from' with the 168 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 169 /// found. 170 static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol); 171 static Operation *lookupNearestSymbolFrom(Operation *from, 172 SymbolRefAttr symbol); 173 template <typename T> lookupNearestSymbolFrom(Operation * from,StringAttr symbol)174 static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) { 175 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 176 } 177 template <typename T> lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)178 static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { 179 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 180 } 181 182 /// This class represents a specific symbol use. 183 class SymbolUse { 184 public: SymbolUse(Operation * op,SymbolRefAttr symbolRef)185 SymbolUse(Operation *op, SymbolRefAttr symbolRef) 186 : owner(op), symbolRef(symbolRef) {} 187 188 /// Return the operation user of this symbol reference. getUser()189 Operation *getUser() const { return owner; } 190 191 /// Return the symbol reference that this use represents. getSymbolRef()192 SymbolRefAttr getSymbolRef() const { return symbolRef; } 193 194 private: 195 /// The operation that this access is held by. 196 Operation *owner; 197 198 /// The symbol reference that this use represents. 199 SymbolRefAttr symbolRef; 200 }; 201 202 /// This class implements a range of SymbolRef uses. 203 class UseRange { 204 public: UseRange(std::vector<SymbolUse> && uses)205 UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {} 206 207 using iterator = std::vector<SymbolUse>::const_iterator; begin()208 iterator begin() const { return uses.begin(); } end()209 iterator end() const { return uses.end(); } empty()210 bool empty() const { return uses.empty(); } 211 212 private: 213 std::vector<SymbolUse> uses; 214 }; 215 216 /// Get an iterator range for all of the uses, for any symbol, that are nested 217 /// within the given operation 'from'. This does not traverse into any nested 218 /// symbol tables. This function returns std::nullopt if there are any unknown 219 /// operations that may potentially be symbol tables. 220 static std::optional<UseRange> getSymbolUses(Operation *from); 221 static std::optional<UseRange> getSymbolUses(Region *from); 222 223 /// Get all of the uses of the given symbol that are nested within the given 224 /// operation 'from'. This does not traverse into any nested symbol tables. 225 /// This function returns std::nullopt if there are any unknown operations 226 /// that may potentially be symbol tables. 227 static std::optional<UseRange> getSymbolUses(StringAttr symbol, 228 Operation *from); 229 static std::optional<UseRange> getSymbolUses(Operation *symbol, 230 Operation *from); 231 static std::optional<UseRange> getSymbolUses(StringAttr symbol, Region *from); 232 static std::optional<UseRange> getSymbolUses(Operation *symbol, Region *from); 233 234 /// Return if the given symbol is known to have no uses that are nested 235 /// within the given operation 'from'. This does not traverse into any nested 236 /// symbol tables. This function will also return false if there are any 237 /// unknown operations that may potentially be symbol tables. This doesn't 238 /// necessarily mean that there are no uses, we just can't conservatively 239 /// prove it. 240 static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from); 241 static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); 242 static bool symbolKnownUseEmpty(StringAttr symbol, Region *from); 243 static bool symbolKnownUseEmpty(Operation *symbol, Region *from); 244 245 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the 246 /// provided symbol 'newSymbol' that are nested within the given operation 247 /// 'from'. This does not traverse into any nested symbol tables. If there are 248 /// any unknown operations that may potentially be symbol tables, no uses are 249 /// replaced and failure is returned. 250 static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, 251 StringAttr newSymbol, 252 Operation *from); 253 static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, 254 StringAttr newSymbolName, 255 Operation *from); 256 static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, 257 StringAttr newSymbol, Region *from); 258 static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, 259 StringAttr newSymbolName, 260 Region *from); 261 262 private: 263 Operation *symbolTableOp; 264 265 /// This is a mapping from a name to the symbol with that name. They key is 266 /// always known to be a StringAttr. 267 DenseMap<Attribute, Operation *> symbolTable; 268 269 /// This is used when name conflicts are detected. 270 unsigned uniquingCounter = 0; 271 }; 272 273 raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility); 274 275 //===----------------------------------------------------------------------===// 276 // SymbolTableCollection 277 //===----------------------------------------------------------------------===// 278 279 /// This class represents a collection of `SymbolTable`s. This simplifies 280 /// certain algorithms that run recursively on nested symbol tables. Symbol 281 /// tables are constructed lazily to reduce the upfront cost of constructing 282 /// unnecessary tables. 283 class SymbolTableCollection { 284 public: 285 /// Look up a symbol with the specified name within the specified symbol table 286 /// operation, returning null if no such name exists. 287 Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol); 288 Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name); 289 template <typename T, typename NameT> lookupSymbolIn(Operation * symbolTableOp,NameT && name)290 T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) { 291 return dyn_cast_or_null<T>( 292 lookupSymbolIn(symbolTableOp, std::forward<NameT>(name))); 293 } 294 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced 295 /// by a given SymbolRefAttr when resolved within the provided symbol table 296 /// operation. Returns failure if any of the nested references could not be 297 /// resolved. 298 LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name, 299 SmallVectorImpl<Operation *> &symbols); 300 301 /// Returns the operation registered with the given symbol name within the 302 /// closest parent operation of, or including, 'from' with the 303 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 304 /// found. 305 Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol); 306 Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol); 307 template <typename T> lookupNearestSymbolFrom(Operation * from,StringAttr symbol)308 T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) { 309 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 310 } 311 template <typename T> lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)312 T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { 313 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 314 } 315 316 /// Lookup, or create, a symbol table for an operation. 317 SymbolTable &getSymbolTable(Operation *op); 318 319 private: 320 friend class LockedSymbolTableCollection; 321 322 /// The constructed symbol tables nested within this table. 323 DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables; 324 }; 325 326 //===----------------------------------------------------------------------===// 327 // LockedSymbolTableCollection 328 //===----------------------------------------------------------------------===// 329 330 /// This class implements a lock-based shared wrapper around a symbol table 331 /// collection that allows shared access to the collection of symbol tables. 332 /// This class does not protect shared access to individual symbol tables. 333 /// `SymbolTableCollection` lazily instantiates `SymbolTable` instances for 334 /// symbol table operations, making read operations not thread-safe. This class 335 /// provides a thread-safe `lookupSymbolIn` implementation by synchronizing the 336 /// lazy `SymbolTable` lookup. 337 class LockedSymbolTableCollection : public SymbolTableCollection { 338 public: LockedSymbolTableCollection(SymbolTableCollection & collection)339 explicit LockedSymbolTableCollection(SymbolTableCollection &collection) 340 : collection(collection) {} 341 342 /// Look up a symbol with the specified name within the specified symbol table 343 /// operation, returning null if no such name exists. 344 Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol); 345 /// Look up a symbol with the specified name within the specified symbol table 346 /// operation, returning null if no such name exists. 347 Operation *lookupSymbolIn(Operation *symbolTableOp, FlatSymbolRefAttr symbol); 348 /// Look up a potentially nested symbol within the specified symbol table 349 /// operation, returning null if no such symbol exists. 350 Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name); 351 352 /// Lookup a symbol of a particular kind within the specified symbol table, 353 /// returning null if the symbol was not found. 354 template <typename T, typename NameT> lookupSymbolIn(Operation * symbolTableOp,NameT && name)355 T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) { 356 return dyn_cast_or_null<T>( 357 lookupSymbolIn(symbolTableOp, std::forward<NameT>(name))); 358 } 359 360 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced 361 /// by a given SymbolRefAttr when resolved within the provided symbol table 362 /// operation. Returns failure if any of the nested references could not be 363 /// resolved. 364 LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name, 365 SmallVectorImpl<Operation *> &symbols); 366 367 private: 368 /// Get the symbol table for the symbol table operation, constructing if it 369 /// does not exist. This function provides thread safety over `collection` 370 /// by locking when performing the lookup and when inserting 371 /// lazily-constructed symbol tables. 372 SymbolTable &getSymbolTable(Operation *symbolTableOp); 373 374 /// The symbol tables to manage. 375 SymbolTableCollection &collection; 376 /// The mutex protecting access to the symbol table collection. 377 llvm::sys::SmartRWMutex<true> mutex; 378 }; 379 380 //===----------------------------------------------------------------------===// 381 // SymbolUserMap 382 //===----------------------------------------------------------------------===// 383 384 /// This class represents a map of symbols to users, and provides efficient 385 /// implementations of symbol queries related to users; such as collecting the 386 /// users of a symbol, replacing all uses, etc. 387 class SymbolUserMap { 388 public: 389 /// Build a user map for all of the symbols defined in regions nested under 390 /// 'symbolTableOp'. A reference to the provided symbol table collection is 391 /// kept by the user map to ensure efficient lookups, thus the lifetime should 392 /// extend beyond that of this map. 393 SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp); 394 395 /// Return the users of the provided symbol operation. getUsers(Operation * symbol)396 ArrayRef<Operation *> getUsers(Operation *symbol) const { 397 auto it = symbolToUsers.find(symbol); 398 return it != symbolToUsers.end() ? it->second.getArrayRef() : std::nullopt; 399 } 400 401 /// Return true if the given symbol has no uses. useEmpty(Operation * symbol)402 bool useEmpty(Operation *symbol) const { 403 return !symbolToUsers.count(symbol); 404 } 405 406 /// Replace all of the uses of the given symbol with `newSymbolName`. 407 void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName); 408 409 private: 410 /// A reference to the symbol table used to construct this map. 411 SymbolTableCollection &symbolTable; 412 413 /// A map of symbol operations to symbol users. 414 DenseMap<Operation *, SetVector<Operation *>> symbolToUsers; 415 }; 416 417 //===----------------------------------------------------------------------===// 418 // SymbolTable Trait Types 419 //===----------------------------------------------------------------------===// 420 421 namespace detail { 422 LogicalResult verifySymbolTable(Operation *op); 423 LogicalResult verifySymbol(Operation *op); 424 } // namespace detail 425 426 namespace OpTrait { 427 /// A trait used to provide symbol table functionalities to a region operation. 428 /// This operation must hold exactly 1 region. Once attached, all operations 429 /// that are directly within the region, i.e not including those within child 430 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will 431 /// be verified to ensure that the names are uniqued. These operations must also 432 /// adhere to the constraints defined by the `Symbol` trait, even if they do not 433 /// inherit from it. 434 template <typename ConcreteType> 435 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> { 436 public: verifyRegionTrait(Operation * op)437 static LogicalResult verifyRegionTrait(Operation *op) { 438 return ::mlir::detail::verifySymbolTable(op); 439 } 440 441 /// Look up a symbol with the specified name, returning null if no such 442 /// name exists. Symbol names never include the @ on them. Note: This 443 /// performs a linear scan of held symbols. lookupSymbol(StringAttr name)444 Operation *lookupSymbol(StringAttr name) { 445 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); 446 } 447 template <typename T> lookupSymbol(StringAttr name)448 T lookupSymbol(StringAttr name) { 449 return dyn_cast_or_null<T>(lookupSymbol(name)); 450 } lookupSymbol(SymbolRefAttr symbol)451 Operation *lookupSymbol(SymbolRefAttr symbol) { 452 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol); 453 } 454 template <typename T> lookupSymbol(SymbolRefAttr symbol)455 T lookupSymbol(SymbolRefAttr symbol) { 456 return dyn_cast_or_null<T>(lookupSymbol(symbol)); 457 } 458 lookupSymbol(StringRef name)459 Operation *lookupSymbol(StringRef name) { 460 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); 461 } 462 template <typename T> lookupSymbol(StringRef name)463 T lookupSymbol(StringRef name) { 464 return dyn_cast_or_null<T>(lookupSymbol(name)); 465 } 466 }; 467 468 } // namespace OpTrait 469 470 //===----------------------------------------------------------------------===// 471 // Visibility parsing implementation. 472 //===----------------------------------------------------------------------===// 473 474 namespace impl { 475 /// Parse an optional visibility attribute keyword (i.e., public, private, or 476 /// nested) without quotes in a string attribute named 'attrName'. 477 ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, 478 NamedAttrList &attrs); 479 } // namespace impl 480 481 } // namespace mlir 482 483 /// Include the generated symbol interfaces. 484 #include "mlir/IR/SymbolInterfaces.h.inc" 485 486 #endif // MLIR_IR_SYMBOLTABLE_H 487