xref: /llvm-project/mlir/include/mlir/IR/SymbolTable.h (revision ea84897ba3e7727a3aa3fbd6d84b6b4ab573c70d)
1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_SYMBOLTABLE_H
10 #define MLIR_IR_SYMBOLTABLE_H
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/StringMap.h"
16 #include "llvm/Support/RWMutex.h"
17 
18 namespace mlir {
19 
20 /// This class allows for representing and managing the symbol table used by
21 /// operations with the 'SymbolTable' trait. Inserting into and erasing from
22 /// this SymbolTable will also insert and erase from the Operation given to it
23 /// at construction.
24 class SymbolTable {
25 public:
26   /// Build a symbol table with the symbols within the given operation.
27   SymbolTable(Operation *symbolTableOp);
28 
29   /// Look up a symbol with the specified name, returning null if no such
30   /// name exists. Names never include the @ on them.
31   Operation *lookup(StringRef name) const;
32   template <typename T>
lookup(StringRef name)33   T lookup(StringRef name) const {
34     return dyn_cast_or_null<T>(lookup(name));
35   }
36 
37   /// Look up a symbol with the specified name, returning null if no such
38   /// name exists. Names never include the @ on them.
39   Operation *lookup(StringAttr name) const;
40   template <typename T>
lookup(StringAttr name)41   T lookup(StringAttr name) const {
42     return dyn_cast_or_null<T>(lookup(name));
43   }
44 
45   /// Remove the given symbol from the table, without deleting it.
46   void remove(Operation *op);
47 
48   /// Erase the given symbol from the table and delete the operation.
49   void erase(Operation *symbol);
50 
51   /// Insert a new symbol into the table, and rename it as necessary to avoid
52   /// collisions. Also insert at the specified location in the body of the
53   /// associated operation if it is not already there. It is asserted that the
54   /// symbol is not inside another operation. Return the name of the symbol
55   /// after insertion as attribute.
56   StringAttr insert(Operation *symbol, Block::iterator insertPt = {});
57 
58   /// Renames the given op or the op refered to by the given name to the given
59   /// new name and updates the symbol table and all usages of the symbol
60   /// accordingly. Fails if the updating of the usages fails.
61   LogicalResult rename(StringAttr from, StringAttr to);
62   LogicalResult rename(Operation *op, StringAttr to);
63   LogicalResult rename(StringAttr from, StringRef to);
64   LogicalResult rename(Operation *op, StringRef to);
65 
66   /// Renames the given op or the op refered to by the given name to the a name
67   /// that is unique within this and the provided other symbol tables and
68   /// updates the symbol table and all usages of the symbol accordingly. Returns
69   /// the new name or failure if the renaming fails.
70   FailureOr<StringAttr> renameToUnique(StringAttr from,
71                                        ArrayRef<SymbolTable *> others);
72   FailureOr<StringAttr> renameToUnique(Operation *op,
73                                        ArrayRef<SymbolTable *> others);
74 
75   /// Return the name of the attribute used for symbol names.
getSymbolAttrName()76   static StringRef getSymbolAttrName() { return "sym_name"; }
77 
78   /// Returns the associated operation.
getOp()79   Operation *getOp() const { return symbolTableOp; }
80 
81   /// Return the name of the attribute used for symbol visibility.
getVisibilityAttrName()82   static StringRef getVisibilityAttrName() { return "sym_visibility"; }
83 
84   //===--------------------------------------------------------------------===//
85   // Symbol Utilities
86   //===--------------------------------------------------------------------===//
87 
88   /// An enumeration detailing the different visibility types that a symbol may
89   /// have.
90   enum class Visibility {
91     /// The symbol is public and may be referenced anywhere internal or external
92     /// to the visible references in the IR.
93     Public,
94 
95     /// The symbol is private and may only be referenced by SymbolRefAttrs local
96     /// to the operations within the current symbol table.
97     Private,
98 
99     /// The symbol is visible to the current IR, which may include operations in
100     /// symbol tables above the one that owns the current symbol. `Nested`
101     /// visibility allows for referencing a symbol outside of its current symbol
102     /// table, while retaining the ability to observe all uses.
103     Nested,
104   };
105 
106   /// Generate a unique symbol name. Iteratively increase uniquingCounter
107   /// and use it as a suffix for symbol names until uniqueChecker does not
108   /// detect any conflict.
109   template <unsigned N, typename UniqueChecker>
generateSymbolName(StringRef name,UniqueChecker uniqueChecker,unsigned & uniquingCounter)110   static SmallString<N> generateSymbolName(StringRef name,
111                                            UniqueChecker uniqueChecker,
112                                            unsigned &uniquingCounter) {
113     SmallString<N> nameBuffer(name);
114     unsigned originalLength = nameBuffer.size();
115     do {
116       nameBuffer.resize(originalLength);
117       nameBuffer += '_';
118       nameBuffer += std::to_string(uniquingCounter++);
119     } while (uniqueChecker(nameBuffer));
120 
121     return nameBuffer;
122   }
123 
124   /// Returns the name of the given symbol operation, aborting if no symbol is
125   /// present.
126   static StringAttr getSymbolName(Operation *symbol);
127 
128   /// Sets the name of the given symbol operation.
129   static void setSymbolName(Operation *symbol, StringAttr name);
setSymbolName(Operation * symbol,StringRef name)130   static void setSymbolName(Operation *symbol, StringRef name) {
131     setSymbolName(symbol, StringAttr::get(symbol->getContext(), name));
132   }
133 
134   /// Returns the visibility of the given symbol operation.
135   static Visibility getSymbolVisibility(Operation *symbol);
136   /// Sets the visibility of the given symbol operation.
137   static void setSymbolVisibility(Operation *symbol, Visibility vis);
138 
139   /// Returns the nearest symbol table from a given operation `from`. Returns
140   /// nullptr if no valid parent symbol table could be found.
141   static Operation *getNearestSymbolTable(Operation *from);
142 
143   /// Walks all symbol table operations nested within, and including, `op`. For
144   /// each symbol table operation, the provided callback is invoked with the op
145   /// and a boolean signifying if the symbols within that symbol table can be
146   /// treated as if all uses within the IR are visible to the caller.
147   /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
148   /// within `op` are visible.
149   static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
150                                function_ref<void(Operation *, bool)> callback);
151 
152   /// Returns the operation registered with the given symbol name with the
153   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
154   /// with the 'OpTrait::SymbolTable' trait.
155   static Operation *lookupSymbolIn(Operation *op, StringAttr symbol);
lookupSymbolIn(Operation * op,StringRef symbol)156   static Operation *lookupSymbolIn(Operation *op, StringRef symbol) {
157     return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol));
158   }
159   static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
160   /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
161   /// by a given SymbolRefAttr. Returns failure if any of the nested references
162   /// could not be resolved.
163   static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol,
164                                       SmallVectorImpl<Operation *> &symbols);
165 
166   /// Returns the operation registered with the given symbol name within the
167   /// closest parent operation of, or including, 'from' with the
168   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
169   /// found.
170   static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
171   static Operation *lookupNearestSymbolFrom(Operation *from,
172                                             SymbolRefAttr symbol);
173   template <typename T>
lookupNearestSymbolFrom(Operation * from,StringAttr symbol)174   static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
175     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
176   }
177   template <typename T>
lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)178   static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
179     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
180   }
181 
182   /// This class represents a specific symbol use.
183   class SymbolUse {
184   public:
SymbolUse(Operation * op,SymbolRefAttr symbolRef)185     SymbolUse(Operation *op, SymbolRefAttr symbolRef)
186         : owner(op), symbolRef(symbolRef) {}
187 
188     /// Return the operation user of this symbol reference.
getUser()189     Operation *getUser() const { return owner; }
190 
191     /// Return the symbol reference that this use represents.
getSymbolRef()192     SymbolRefAttr getSymbolRef() const { return symbolRef; }
193 
194   private:
195     /// The operation that this access is held by.
196     Operation *owner;
197 
198     /// The symbol reference that this use represents.
199     SymbolRefAttr symbolRef;
200   };
201 
202   /// This class implements a range of SymbolRef uses.
203   class UseRange {
204   public:
UseRange(std::vector<SymbolUse> && uses)205     UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {}
206 
207     using iterator = std::vector<SymbolUse>::const_iterator;
begin()208     iterator begin() const { return uses.begin(); }
end()209     iterator end() const { return uses.end(); }
empty()210     bool empty() const { return uses.empty(); }
211 
212   private:
213     std::vector<SymbolUse> uses;
214   };
215 
216   /// Get an iterator range for all of the uses, for any symbol, that are nested
217   /// within the given operation 'from'. This does not traverse into any nested
218   /// symbol tables. This function returns std::nullopt if there are any unknown
219   /// operations that may potentially be symbol tables.
220   static std::optional<UseRange> getSymbolUses(Operation *from);
221   static std::optional<UseRange> getSymbolUses(Region *from);
222 
223   /// Get all of the uses of the given symbol that are nested within the given
224   /// operation 'from'. This does not traverse into any nested symbol tables.
225   /// This function returns std::nullopt if there are any unknown operations
226   /// that may potentially be symbol tables.
227   static std::optional<UseRange> getSymbolUses(StringAttr symbol,
228                                                Operation *from);
229   static std::optional<UseRange> getSymbolUses(Operation *symbol,
230                                                Operation *from);
231   static std::optional<UseRange> getSymbolUses(StringAttr symbol, Region *from);
232   static std::optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
233 
234   /// Return if the given symbol is known to have no uses that are nested
235   /// within the given operation 'from'. This does not traverse into any nested
236   /// symbol tables. This function will also return false if there are any
237   /// unknown operations that may potentially be symbol tables. This doesn't
238   /// necessarily mean that there are no uses, we just can't conservatively
239   /// prove it.
240   static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from);
241   static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
242   static bool symbolKnownUseEmpty(StringAttr symbol, Region *from);
243   static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
244 
245   /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
246   /// provided symbol 'newSymbol' that are nested within the given operation
247   /// 'from'. This does not traverse into any nested symbol tables. If there are
248   /// any unknown operations that may potentially be symbol tables, no uses are
249   /// replaced and failure is returned.
250   static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
251                                             StringAttr newSymbol,
252                                             Operation *from);
253   static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
254                                             StringAttr newSymbolName,
255                                             Operation *from);
256   static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
257                                             StringAttr newSymbol, Region *from);
258   static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
259                                             StringAttr newSymbolName,
260                                             Region *from);
261 
262 private:
263   Operation *symbolTableOp;
264 
265   /// This is a mapping from a name to the symbol with that name.  They key is
266   /// always known to be a StringAttr.
267   DenseMap<Attribute, Operation *> symbolTable;
268 
269   /// This is used when name conflicts are detected.
270   unsigned uniquingCounter = 0;
271 };
272 
273 raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility);
274 
275 //===----------------------------------------------------------------------===//
276 // SymbolTableCollection
277 //===----------------------------------------------------------------------===//
278 
279 /// This class represents a collection of `SymbolTable`s. This simplifies
280 /// certain algorithms that run recursively on nested symbol tables. Symbol
281 /// tables are constructed lazily to reduce the upfront cost of constructing
282 /// unnecessary tables.
283 class SymbolTableCollection {
284 public:
285   /// Look up a symbol with the specified name within the specified symbol table
286   /// operation, returning null if no such name exists.
287   Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
288   Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
289   template <typename T, typename NameT>
lookupSymbolIn(Operation * symbolTableOp,NameT && name)290   T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
291     return dyn_cast_or_null<T>(
292         lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
293   }
294   /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
295   /// by a given SymbolRefAttr when resolved within the provided symbol table
296   /// operation. Returns failure if any of the nested references could not be
297   /// resolved.
298   LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
299                                SmallVectorImpl<Operation *> &symbols);
300 
301   /// Returns the operation registered with the given symbol name within the
302   /// closest parent operation of, or including, 'from' with the
303   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
304   /// found.
305   Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
306   Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
307   template <typename T>
lookupNearestSymbolFrom(Operation * from,StringAttr symbol)308   T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
309     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
310   }
311   template <typename T>
lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)312   T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
313     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
314   }
315 
316   /// Lookup, or create, a symbol table for an operation.
317   SymbolTable &getSymbolTable(Operation *op);
318 
319 private:
320   friend class LockedSymbolTableCollection;
321 
322   /// The constructed symbol tables nested within this table.
323   DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
324 };
325 
326 //===----------------------------------------------------------------------===//
327 // LockedSymbolTableCollection
328 //===----------------------------------------------------------------------===//
329 
330 /// This class implements a lock-based shared wrapper around a symbol table
331 /// collection that allows shared access to the collection of symbol tables.
332 /// This class does not protect shared access to individual symbol tables.
333 /// `SymbolTableCollection` lazily instantiates `SymbolTable` instances for
334 /// symbol table operations, making read operations not thread-safe. This class
335 /// provides a thread-safe `lookupSymbolIn` implementation by synchronizing the
336 /// lazy `SymbolTable` lookup.
337 class LockedSymbolTableCollection : public SymbolTableCollection {
338 public:
LockedSymbolTableCollection(SymbolTableCollection & collection)339   explicit LockedSymbolTableCollection(SymbolTableCollection &collection)
340       : collection(collection) {}
341 
342   /// Look up a symbol with the specified name within the specified symbol table
343   /// operation, returning null if no such name exists.
344   Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
345   /// Look up a symbol with the specified name within the specified symbol table
346   /// operation, returning null if no such name exists.
347   Operation *lookupSymbolIn(Operation *symbolTableOp, FlatSymbolRefAttr symbol);
348   /// Look up a potentially nested symbol within the specified symbol table
349   /// operation, returning null if no such symbol exists.
350   Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
351 
352   /// Lookup a symbol of a particular kind within the specified symbol table,
353   /// returning null if the symbol was not found.
354   template <typename T, typename NameT>
lookupSymbolIn(Operation * symbolTableOp,NameT && name)355   T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
356     return dyn_cast_or_null<T>(
357         lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
358   }
359 
360   /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
361   /// by a given SymbolRefAttr when resolved within the provided symbol table
362   /// operation. Returns failure if any of the nested references could not be
363   /// resolved.
364   LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
365                                SmallVectorImpl<Operation *> &symbols);
366 
367 private:
368   /// Get the symbol table for the symbol table operation, constructing if it
369   /// does not exist. This function provides thread safety over `collection`
370   /// by locking when performing the lookup and when inserting
371   /// lazily-constructed symbol tables.
372   SymbolTable &getSymbolTable(Operation *symbolTableOp);
373 
374   /// The symbol tables to manage.
375   SymbolTableCollection &collection;
376   /// The mutex protecting access to the symbol table collection.
377   llvm::sys::SmartRWMutex<true> mutex;
378 };
379 
380 //===----------------------------------------------------------------------===//
381 // SymbolUserMap
382 //===----------------------------------------------------------------------===//
383 
384 /// This class represents a map of symbols to users, and provides efficient
385 /// implementations of symbol queries related to users; such as collecting the
386 /// users of a symbol, replacing all uses, etc.
387 class SymbolUserMap {
388 public:
389   /// Build a user map for all of the symbols defined in regions nested under
390   /// 'symbolTableOp'. A reference to the provided symbol table collection is
391   /// kept by the user map to ensure efficient lookups, thus the lifetime should
392   /// extend beyond that of this map.
393   SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp);
394 
395   /// Return the users of the provided symbol operation.
getUsers(Operation * symbol)396   ArrayRef<Operation *> getUsers(Operation *symbol) const {
397     auto it = symbolToUsers.find(symbol);
398     return it != symbolToUsers.end() ? it->second.getArrayRef() : std::nullopt;
399   }
400 
401   /// Return true if the given symbol has no uses.
useEmpty(Operation * symbol)402   bool useEmpty(Operation *symbol) const {
403     return !symbolToUsers.count(symbol);
404   }
405 
406   /// Replace all of the uses of the given symbol with `newSymbolName`.
407   void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName);
408 
409 private:
410   /// A reference to the symbol table used to construct this map.
411   SymbolTableCollection &symbolTable;
412 
413   /// A map of symbol operations to symbol users.
414   DenseMap<Operation *, SetVector<Operation *>> symbolToUsers;
415 };
416 
417 //===----------------------------------------------------------------------===//
418 // SymbolTable Trait Types
419 //===----------------------------------------------------------------------===//
420 
421 namespace detail {
422 LogicalResult verifySymbolTable(Operation *op);
423 LogicalResult verifySymbol(Operation *op);
424 } // namespace detail
425 
426 namespace OpTrait {
427 /// A trait used to provide symbol table functionalities to a region operation.
428 /// This operation must hold exactly 1 region. Once attached, all operations
429 /// that are directly within the region, i.e not including those within child
430 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will
431 /// be verified to ensure that the names are uniqued. These operations must also
432 /// adhere to the constraints defined by the `Symbol` trait, even if they do not
433 /// inherit from it.
434 template <typename ConcreteType>
435 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
436 public:
verifyRegionTrait(Operation * op)437   static LogicalResult verifyRegionTrait(Operation *op) {
438     return ::mlir::detail::verifySymbolTable(op);
439   }
440 
441   /// Look up a symbol with the specified name, returning null if no such
442   /// name exists. Symbol names never include the @ on them. Note: This
443   /// performs a linear scan of held symbols.
lookupSymbol(StringAttr name)444   Operation *lookupSymbol(StringAttr name) {
445     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
446   }
447   template <typename T>
lookupSymbol(StringAttr name)448   T lookupSymbol(StringAttr name) {
449     return dyn_cast_or_null<T>(lookupSymbol(name));
450   }
lookupSymbol(SymbolRefAttr symbol)451   Operation *lookupSymbol(SymbolRefAttr symbol) {
452     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol);
453   }
454   template <typename T>
lookupSymbol(SymbolRefAttr symbol)455   T lookupSymbol(SymbolRefAttr symbol) {
456     return dyn_cast_or_null<T>(lookupSymbol(symbol));
457   }
458 
lookupSymbol(StringRef name)459   Operation *lookupSymbol(StringRef name) {
460     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
461   }
462   template <typename T>
lookupSymbol(StringRef name)463   T lookupSymbol(StringRef name) {
464     return dyn_cast_or_null<T>(lookupSymbol(name));
465   }
466 };
467 
468 } // namespace OpTrait
469 
470 //===----------------------------------------------------------------------===//
471 // Visibility parsing implementation.
472 //===----------------------------------------------------------------------===//
473 
474 namespace impl {
475 /// Parse an optional visibility attribute keyword (i.e., public, private, or
476 /// nested) without quotes in a string attribute named 'attrName'.
477 ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser,
478                                            NamedAttrList &attrs);
479 } // namespace impl
480 
481 } // namespace mlir
482 
483 /// Include the generated symbol interfaces.
484 #include "mlir/IR/SymbolInterfaces.h.inc"
485 
486 #endif // MLIR_IR_SYMBOLTABLE_H
487