xref: /llvm-project/mlir/lib/IR/SymbolTable.cpp (revision 01eedbc7c14859c273bbd98693c67f35c59e8d85)
1 //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/SymbolTable.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/OpImplementation.h"
12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallString.h"
15 #include "llvm/ADT/StringSwitch.h"
16 
17 using namespace mlir;
18 
19 /// Return true if the given operation is unknown and may potentially define a
20 /// symbol table.
21 static bool isPotentiallyUnknownSymbolTable(Operation *op) {
22   return op->getNumRegions() == 1 && !op->getDialect();
23 }
24 
25 /// Returns the string name of the given symbol, or null if this is not a
26 /// symbol.
27 static StringAttr getNameIfSymbol(Operation *op) {
28   return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
29 }
30 static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
31   return op->getAttrOfType<StringAttr>(symbolAttrNameId);
32 }
33 
34 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
35 /// that are usable within the symbol table operations from 'symbol' as far up
36 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
37 /// Returns success if all references up to 'within' could be computed.
38 static LogicalResult
39 collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
40                           Operation *within,
41                           SmallVectorImpl<SymbolRefAttr> &results) {
42   assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
43   MLIRContext *ctx = symbol->getContext();
44 
45   auto leafRef = FlatSymbolRefAttr::get(symbolName);
46   results.push_back(leafRef);
47 
48   // Early exit for when 'within' is the parent of 'symbol'.
49   Operation *symbolTableOp = symbol->getParentOp();
50   if (within == symbolTableOp)
51     return success();
52 
53   // Collect references until 'symbolTableOp' reaches 'within'.
54   SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
55   StringAttr symbolNameId =
56       StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
57   do {
58     // Each parent of 'symbol' should define a symbol table.
59     if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
60       return failure();
61     // Each parent of 'symbol' should also be a symbol.
62     StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
63     if (!symbolTableName)
64       return failure();
65     results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
66 
67     symbolTableOp = symbolTableOp->getParentOp();
68     if (symbolTableOp == within)
69       break;
70     nestedRefs.insert(nestedRefs.begin(),
71                       FlatSymbolRefAttr::get(symbolTableName));
72   } while (true);
73   return success();
74 }
75 
76 /// Walk all of the operations within the given set of regions, without
77 /// traversing into any nested symbol tables. Stops walking if the result of the
78 /// callback is anything other than `WalkResult::advance`.
79 static Optional<WalkResult>
80 walkSymbolTable(MutableArrayRef<Region> regions,
81                 function_ref<Optional<WalkResult>(Operation *)> callback) {
82   SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
83   while (!worklist.empty()) {
84     for (Operation &op : worklist.pop_back_val()->getOps()) {
85       Optional<WalkResult> result = callback(&op);
86       if (result != WalkResult::advance())
87         return result;
88 
89       // If this op defines a new symbol table scope, we can't traverse. Any
90       // symbol references nested within 'op' are different semantically.
91       if (!op.hasTrait<OpTrait::SymbolTable>()) {
92         for (Region &region : op.getRegions())
93           worklist.push_back(&region);
94       }
95     }
96   }
97   return WalkResult::advance();
98 }
99 
100 /// Walk all of the operations nested under, and including, the given operation,
101 /// without traversing into any nested symbol tables. Stops walking if the
102 /// result of the callback is anything other than `WalkResult::advance`.
103 static Optional<WalkResult>
104 walkSymbolTable(Operation *op,
105                 function_ref<Optional<WalkResult>(Operation *)> callback) {
106   Optional<WalkResult> result = callback(op);
107   if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
108     return result;
109   return walkSymbolTable(op->getRegions(), callback);
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // SymbolTable
114 //===----------------------------------------------------------------------===//
115 
116 /// Build a symbol table with the symbols within the given operation.
117 SymbolTable::SymbolTable(Operation *symbolTableOp)
118     : symbolTableOp(symbolTableOp) {
119   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
120          "expected operation to have SymbolTable trait");
121   assert(symbolTableOp->getNumRegions() == 1 &&
122          "expected operation to have a single region");
123   assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
124          "expected operation to have a single block");
125 
126   StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
127                                             SymbolTable::getSymbolAttrName());
128   for (auto &op : symbolTableOp->getRegion(0).front()) {
129     StringAttr name = getNameIfSymbol(&op, symbolNameId);
130     if (!name)
131       continue;
132 
133     auto inserted = symbolTable.insert({name, &op});
134     (void)inserted;
135     assert(inserted.second &&
136            "expected region to contain uniquely named symbol operations");
137   }
138 }
139 
140 /// Look up a symbol with the specified name, returning null if no such name
141 /// exists. Names never include the @ on them.
142 Operation *SymbolTable::lookup(StringRef name) const {
143   return lookup(StringAttr::get(symbolTableOp->getContext(), name));
144 }
145 Operation *SymbolTable::lookup(StringAttr name) const {
146   return symbolTable.lookup(name);
147 }
148 
149 /// Erase the given symbol from the table.
150 void SymbolTable::erase(Operation *symbol) {
151   StringAttr name = getNameIfSymbol(symbol);
152   assert(name && "expected valid 'name' attribute");
153   assert(symbol->getParentOp() == symbolTableOp &&
154          "expected this operation to be inside of the operation with this "
155          "SymbolTable");
156 
157   auto it = symbolTable.find(name);
158   if (it != symbolTable.end() && it->second == symbol) {
159     symbolTable.erase(it);
160     symbol->erase();
161   }
162 }
163 
164 // TODO: Consider if this should be renamed to something like insertOrUpdate
165 /// Insert a new symbol into the table and associated operation if not already
166 /// there and rename it as necessary to avoid collisions. Return the name of
167 /// the symbol after insertion as attribute.
168 StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
169   // The symbol cannot be the child of another op and must be the child of the
170   // symbolTableOp after this.
171   //
172   // TODO: consider if SymbolTable's constructor should behave the same.
173   if (!symbol->getParentOp()) {
174     auto &body = symbolTableOp->getRegion(0).front();
175     if (insertPt == Block::iterator()) {
176       insertPt = Block::iterator(body.end());
177     } else {
178       assert((insertPt == body.end() ||
179               insertPt->getParentOp() == symbolTableOp) &&
180              "expected insertPt to be in the associated module operation");
181     }
182     // Insert before the terminator, if any.
183     if (insertPt == Block::iterator(body.end()) && !body.empty() &&
184         std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
185       insertPt = std::prev(body.end());
186 
187     body.getOperations().insert(insertPt, symbol);
188   }
189   assert(symbol->getParentOp() == symbolTableOp &&
190          "symbol is already inserted in another op");
191 
192   // Add this symbol to the symbol table, uniquing the name if a conflict is
193   // detected.
194   StringAttr name = getSymbolName(symbol);
195   if (symbolTable.insert({name, symbol}).second)
196     return name;
197   // If the symbol was already in the table, also return.
198   if (symbolTable.lookup(name) == symbol)
199     return name;
200   // If a conflict was detected, then the symbol will not have been added to
201   // the symbol table. Try suffixes until we get to a unique name that works.
202   SmallString<128> nameBuffer(name.getValue());
203   unsigned originalLength = nameBuffer.size();
204 
205   MLIRContext *context = symbol->getContext();
206 
207   // Iteratively try suffixes until we find one that isn't used.
208   do {
209     nameBuffer.resize(originalLength);
210     nameBuffer += '_';
211     nameBuffer += std::to_string(uniquingCounter++);
212   } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
213                 .second);
214   setSymbolName(symbol, nameBuffer);
215   return getSymbolName(symbol);
216 }
217 
218 /// Returns the name of the given symbol operation.
219 StringAttr SymbolTable::getSymbolName(Operation *symbol) {
220   StringAttr name = getNameIfSymbol(symbol);
221   assert(name && "expected valid symbol name");
222   return name;
223 }
224 
225 /// Sets the name of the given symbol operation.
226 void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
227   symbol->setAttr(getSymbolAttrName(), name);
228 }
229 
230 /// Returns the visibility of the given symbol operation.
231 SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
232   // If the attribute doesn't exist, assume public.
233   StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
234   if (!vis)
235     return Visibility::Public;
236 
237   // Otherwise, switch on the string value.
238   return StringSwitch<Visibility>(vis.getValue())
239       .Case("private", Visibility::Private)
240       .Case("nested", Visibility::Nested)
241       .Case("public", Visibility::Public);
242 }
243 /// Sets the visibility of the given symbol operation.
244 void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
245   MLIRContext *ctx = symbol->getContext();
246 
247   // If the visibility is public, just drop the attribute as this is the
248   // default.
249   if (vis == Visibility::Public) {
250     symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
251     return;
252   }
253 
254   // Otherwise, update the attribute.
255   assert((vis == Visibility::Private || vis == Visibility::Nested) &&
256          "unknown symbol visibility kind");
257 
258   StringRef visName = vis == Visibility::Private ? "private" : "nested";
259   symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
260 }
261 
262 /// Returns the nearest symbol table from a given operation `from`. Returns
263 /// nullptr if no valid parent symbol table could be found.
264 Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
265   assert(from && "expected valid operation");
266   if (isPotentiallyUnknownSymbolTable(from))
267     return nullptr;
268 
269   while (!from->hasTrait<OpTrait::SymbolTable>()) {
270     from = from->getParentOp();
271 
272     // Check that this is a valid op and isn't an unknown symbol table.
273     if (!from || isPotentiallyUnknownSymbolTable(from))
274       return nullptr;
275   }
276   return from;
277 }
278 
279 /// Walks all symbol table operations nested within, and including, `op`. For
280 /// each symbol table operation, the provided callback is invoked with the op
281 /// and a boolean signifying if the symbols within that symbol table can be
282 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
283 /// all of the symbol uses of symbols within `op` are visible.
284 void SymbolTable::walkSymbolTables(
285     Operation *op, bool allSymUsesVisible,
286     function_ref<void(Operation *, bool)> callback) {
287   bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
288   if (isSymbolTable) {
289     SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
290     allSymUsesVisible |= !symbol || symbol.isPrivate();
291   } else {
292     // Otherwise if 'op' is not a symbol table, any nested symbols are
293     // guaranteed to be hidden.
294     allSymUsesVisible = true;
295   }
296 
297   for (Region &region : op->getRegions())
298     for (Block &block : region)
299       for (Operation &nestedOp : block)
300         walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
301 
302   // If 'op' had the symbol table trait, visit it after any nested symbol
303   // tables.
304   if (isSymbolTable)
305     callback(op, allSymUsesVisible);
306 }
307 
308 /// Returns the operation registered with the given symbol name with the
309 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
310 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
311 /// was found.
312 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
313                                        StringAttr symbol) {
314   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
315   Region &region = symbolTableOp->getRegion(0);
316   if (region.empty())
317     return nullptr;
318 
319   // Look for a symbol with the given name.
320   StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
321                                             SymbolTable::getSymbolAttrName());
322   for (auto &op : region.front())
323     if (getNameIfSymbol(&op, symbolNameId) == symbol)
324       return &op;
325   return nullptr;
326 }
327 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
328                                        SymbolRefAttr symbol) {
329   SmallVector<Operation *, 4> resolvedSymbols;
330   if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
331     return nullptr;
332   return resolvedSymbols.back();
333 }
334 
335 /// Internal implementation of `lookupSymbolIn` that allows for specialized
336 /// implementations of the lookup function.
337 static LogicalResult lookupSymbolInImpl(
338     Operation *symbolTableOp, SymbolRefAttr symbol,
339     SmallVectorImpl<Operation *> &symbols,
340     function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
341   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
342 
343   // Lookup the root reference for this symbol.
344   symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
345   if (!symbolTableOp)
346     return failure();
347   symbols.push_back(symbolTableOp);
348 
349   // If there are no nested references, just return the root symbol directly.
350   ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
351   if (nestedRefs.empty())
352     return success();
353 
354   // Verify that the root is also a symbol table.
355   if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
356     return failure();
357 
358   // Otherwise, lookup each of the nested non-leaf references and ensure that
359   // each corresponds to a valid symbol table.
360   for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
361     symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
362     if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
363       return failure();
364     symbols.push_back(symbolTableOp);
365   }
366   symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
367   return success(symbols.back());
368 }
369 
370 LogicalResult
371 SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
372                             SmallVectorImpl<Operation *> &symbols) {
373   auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
374     return lookupSymbolIn(symbolTableOp, symbol);
375   };
376   return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
377 }
378 
379 /// Returns the operation registered with the given symbol name within the
380 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
381 /// nullptr if no valid symbol was found.
382 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
383                                                 StringAttr symbol) {
384   Operation *symbolTableOp = getNearestSymbolTable(from);
385   return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
386 }
387 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
388                                                 SymbolRefAttr symbol) {
389   Operation *symbolTableOp = getNearestSymbolTable(from);
390   return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
391 }
392 
393 raw_ostream &mlir::operator<<(raw_ostream &os,
394                               SymbolTable::Visibility visibility) {
395   switch (visibility) {
396   case SymbolTable::Visibility::Public:
397     return os << "public";
398   case SymbolTable::Visibility::Private:
399     return os << "private";
400   case SymbolTable::Visibility::Nested:
401     return os << "nested";
402   }
403   llvm_unreachable("Unexpected visibility");
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // SymbolTable Trait Types
408 //===----------------------------------------------------------------------===//
409 
410 LogicalResult detail::verifySymbolTable(Operation *op) {
411   if (op->getNumRegions() != 1)
412     return op->emitOpError()
413            << "Operations with a 'SymbolTable' must have exactly one region";
414   if (!llvm::hasSingleElement(op->getRegion(0)))
415     return op->emitOpError()
416            << "Operations with a 'SymbolTable' must have exactly one block";
417 
418   // Check that all symbols are uniquely named within child regions.
419   DenseMap<Attribute, Location> nameToOrigLoc;
420   for (auto &block : op->getRegion(0)) {
421     for (auto &op : block) {
422       // Check for a symbol name attribute.
423       auto nameAttr =
424           op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName());
425       if (!nameAttr)
426         continue;
427 
428       // Try to insert this symbol into the table.
429       auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
430       if (!it.second)
431         return op.emitError()
432             .append("redefinition of symbol named '", nameAttr.getValue(), "'")
433             .attachNote(it.first->second)
434             .append("see existing symbol definition here");
435     }
436   }
437 
438   // Verify any nested symbol user operations.
439   SymbolTableCollection symbolTable;
440   auto verifySymbolUserFn = [&](Operation *op) -> Optional<WalkResult> {
441     if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
442       return WalkResult(user.verifySymbolUses(symbolTable));
443     return WalkResult::advance();
444   };
445 
446   Optional<WalkResult> result =
447       walkSymbolTable(op->getRegions(), verifySymbolUserFn);
448   return success(result && !result->wasInterrupted());
449 }
450 
451 LogicalResult detail::verifySymbol(Operation *op) {
452   // Verify the name attribute.
453   if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
454     return op->emitOpError() << "requires string attribute '"
455                              << mlir::SymbolTable::getSymbolAttrName() << "'";
456 
457   // Verify the visibility attribute.
458   if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
459     StringAttr visStrAttr = vis.dyn_cast<StringAttr>();
460     if (!visStrAttr)
461       return op->emitOpError() << "requires visibility attribute '"
462                                << mlir::SymbolTable::getVisibilityAttrName()
463                                << "' to be a string attribute, but got " << vis;
464 
465     if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
466                             visStrAttr.getValue()))
467       return op->emitOpError()
468              << "visibility expected to be one of [\"public\", \"private\", "
469                 "\"nested\"], but got "
470              << visStrAttr;
471   }
472   return success();
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // Symbol Use Lists
477 //===----------------------------------------------------------------------===//
478 
479 /// Walk all of the symbol references within the given operation, invoking the
480 /// provided callback for each found use. The callbacks takes the use of the
481 /// symbol.
482 static WalkResult
483 walkSymbolRefs(Operation *op,
484                function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
485   // Check to see if the operation has any attributes.
486   DictionaryAttr attrDict = op->getAttrDictionary();
487   if (attrDict.empty())
488     return WalkResult::advance();
489 
490   // A worklist of a container attribute and the current index into the held
491   // attribute list.
492   struct WorklistItem {
493     SubElementAttrInterface container;
494     SmallVector<Attribute> immediateSubElements;
495 
496     explicit WorklistItem(SubElementAttrInterface container) {
497       SmallVector<Attribute> subElements;
498       container.walkImmediateSubElements(
499           [&](Attribute attr) { subElements.push_back(attr); }, [](Type) {});
500       immediateSubElements = std::move(subElements);
501     }
502   };
503 
504   SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem(attrDict));
505   SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);
506 
507   // Process the symbol references within the given nested attribute range.
508   auto processAttrs = [&](int &index,
509                           WorklistItem &worklistItem) -> WalkResult {
510     for (Attribute attr :
511          llvm::drop_begin(worklistItem.immediateSubElements, index)) {
512       // Invoke the provided callback if we find a symbol use and check for a
513       // requested interrupt.
514       if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>()) {
515         if (callback({op, symbolRef}).wasInterrupted())
516           return WalkResult::interrupt();
517 
518         /// Check for a nested container attribute, these will also need to be
519         /// walked.
520       } else if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
521         attrWorklist.emplace_back(interface);
522         curAccessChain.push_back(-1);
523         return WalkResult::advance();
524       }
525       // Make sure to keep the index counter in sync.
526       ++index;
527     }
528 
529     // Pop this container attribute from the worklist.
530     attrWorklist.pop_back();
531     curAccessChain.pop_back();
532     return WalkResult::advance();
533   };
534 
535   WalkResult result = WalkResult::advance();
536   do {
537     WorklistItem &item = attrWorklist.back();
538     int &index = curAccessChain.back();
539     ++index;
540 
541     // Process the given attribute, which is guaranteed to be a container.
542     result = processAttrs(index, item);
543   } while (!attrWorklist.empty() && !result.wasInterrupted());
544   return result;
545 }
546 
547 /// Walk all of the uses, for any symbol, that are nested within the given
548 /// regions, invoking the provided callback for each. This does not traverse
549 /// into any nested symbol tables.
550 static Optional<WalkResult>
551 walkSymbolUses(MutableArrayRef<Region> regions,
552                function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
553   return walkSymbolTable(regions, [&](Operation *op) -> Optional<WalkResult> {
554     // Check that this isn't a potentially unknown symbol table.
555     if (isPotentiallyUnknownSymbolTable(op))
556       return llvm::None;
557 
558     return walkSymbolRefs(op, callback);
559   });
560 }
561 /// Walk all of the uses, for any symbol, that are nested within the given
562 /// operation 'from', invoking the provided callback for each. This does not
563 /// traverse into any nested symbol tables.
564 static Optional<WalkResult>
565 walkSymbolUses(Operation *from,
566                function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
567   // If this operation has regions, and it, as well as its dialect, isn't
568   // registered then conservatively fail. The operation may define a
569   // symbol table, so we can't opaquely know if we should traverse to find
570   // nested uses.
571   if (isPotentiallyUnknownSymbolTable(from))
572     return llvm::None;
573 
574   // Walk the uses on this operation.
575   if (walkSymbolRefs(from, callback).wasInterrupted())
576     return WalkResult::interrupt();
577 
578   // Only recurse if this operation is not a symbol table. A symbol table
579   // defines a new scope, so we can't walk the attributes from within the symbol
580   // table op.
581   if (!from->hasTrait<OpTrait::SymbolTable>())
582     return walkSymbolUses(from->getRegions(), callback);
583   return WalkResult::advance();
584 }
585 
namespace {
/// This class represents a single symbol scope. A symbol scope represents the
/// set of operations nested within a symbol table that may reference symbols
/// within that table. A symbol scope does not contain the symbol table
/// operation itself, just its contained operations. A scope ends at leaf
/// operations or another symbol table operation.
///
/// `limit` is either a Region or an Operation. `symbol` is the reference that
/// names the target symbol from inside that scope.
struct SymbolScope {
  /// Walk the symbol uses within this scope, invoking the given callback.
  /// This variant is used when the callback type matches that expected by
  /// 'walkSymbolUses', i.e. its result type is not void.
  /// A Region limit is passed as a one-element region range. An Operation
  /// limit goes through the Operation overload of walkSymbolUses, which also
  /// visits the operation's own attributes.
  template <typename CallbackT,
            typename std::enable_if_t<!std::is_same<
                typename llvm::function_traits<CallbackT>::result_t,
                void>::value> * = nullptr>
  Optional<WalkResult> walk(CallbackT cback) {
    if (Region *region = limit.dyn_cast<Region *>())
      return walkSymbolUses(*region, cback);
    return walkSymbolUses(limit.get<Operation *>(), cback);
  }
  /// This variant is used when the callback type matches a stripped down type:
  /// void(SymbolTable::SymbolUse use)
  /// The comma expression below calls `cback` and then yields
  /// `WalkResult::advance()`, so a void callback never interrupts the walk.
  template <typename CallbackT,
            typename std::enable_if_t<std::is_same<
                typename llvm::function_traits<CallbackT>::result_t,
                void>::value> * = nullptr>
  Optional<WalkResult> walk(CallbackT cback) {
    return walk([=](SymbolTable::SymbolUse use) {
      return cback(use), WalkResult::advance();
    });
  }

  /// Walk all of the operations nested under the current scope without
  /// traversing into any nested symbol tables. When the scope is an
  /// Operation, that operation itself is passed to the callback first.
  template <typename CallbackT>
  Optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
    if (Region *region = limit.dyn_cast<Region *>())
      return ::walkSymbolTable(*region, cback);
    return ::walkSymbolTable(limit.get<Operation *>(), cback);
  }

  /// The representation of the symbol within this scope.
  SymbolRefAttr symbol;

  /// The IR unit representing this scope.
  llvm::PointerUnion<Operation *, Region *> limit;
};
} // namespace
633 
/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
/// Each returned scope pairs an IR unit with the reference that names
/// 'symbol' from inside it. An empty result means 'symbol' cannot be named
/// from 'limit' with a SymbolRefAttr.
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
                                                       Operation *limit) {
  StringAttr symName = SymbolTable::getSymbolName(symbol);
  assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);

  // Compute the ancestors of 'limit'.
  SetVector<Operation *, SmallVector<Operation *, 4>,
            SmallPtrSet<Operation *, 4>>
      limitAncestors;
  Operation *limitAncestor = limit;
  do {
    // Check to see if 'symbol' is an ancestor of 'limit'.
    if (limitAncestor == symbol) {
      // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
      // doesn't support parent references.
      if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
          symbol->getParentOp())
        return {{SymbolRefAttr::get(symName), limit}};
      return {};
    }

    limitAncestors.insert(limitAncestor);
  } while ((limitAncestor = limitAncestor->getParentOp()));

  // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
  Operation *commonAncestor = symbol->getParentOp();
  do {
    if (limitAncestors.count(commonAncestor))
      break;
  } while ((commonAncestor = commonAncestor->getParentOp()));
  assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");

  // Compute the set of valid nested references for 'symbol' as far up to the
  // common ancestor as possible. On failure 'references' still holds the
  // prefix that could be computed.
  SmallVector<SymbolRefAttr, 2> references;
  bool collectedAllReferences = succeeded(
      collectValidReferencesFor(symbol, symName, commonAncestor, references));

  // Handle the case where the common ancestor is 'limit'.
  if (commonAncestor == limit) {
    SmallVector<SymbolScope, 2> scopes;

    // Walk each of the ancestors of 'symbol', calling the compute function for
    // each one. references[i] names 'symbol' from inside the region of the
    // i-th enclosing symbol table (i == 0 is 'symbol's parent). If collection
    // stopped early, only the valid prefix produces scopes.
    Operation *limitIt = symbol->getParentOp();
    for (size_t i = 0, e = references.size(); i != e;
         ++i, limitIt = limitIt->getParentOp()) {
      assert(limitIt->hasTrait<OpTrait::SymbolTable>());
      scopes.push_back({references[i], &limitIt->getRegion(0)});
    }
    return scopes;
  }

  // Otherwise, we just need the symbol reference for 'symbol' that will be
  // used within 'limit'. This is the last reference in the list we computed
  // above if we were able to collect all references.
  if (!collectedAllReferences)
    return {};
  return {{references.back(), limit}};
}
695 static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
696                                                        Region *limit) {
697   auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
698 
699   // If we collected some scopes to walk, make sure to constrain the one for
700   // limit to the specific region requested.
701   if (!scopes.empty())
702     scopes.back().limit = limit;
703   return scopes;
704 }
705 template <typename IRUnit>
706 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
707                                                        IRUnit *limit) {
708   return {{SymbolRefAttr::get(symbol), limit}};
709 }
710 
711 /// Returns true if the given reference 'SubRef' is a sub reference of the
712 /// reference 'ref', i.e. 'ref' is a further qualified reference.
713 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
714   if (ref == subRef)
715     return true;
716 
717   // If the references are not pointer equal, check to see if `subRef` is a
718   // prefix of `ref`.
719   if (ref.isa<FlatSymbolRefAttr>() ||
720       ref.getRootReference() != subRef.getRootReference())
721     return false;
722 
723   auto refLeafs = ref.getNestedReferences();
724   auto subRefLeafs = subRef.getNestedReferences();
725   return subRefLeafs.size() < refLeafs.size() &&
726          subRefLeafs == refLeafs.take_front(subRefLeafs.size());
727 }
728 
729 //===----------------------------------------------------------------------===//
730 // SymbolTable::getSymbolUses
731 
732 /// The implementation of SymbolTable::getSymbolUses below.
733 template <typename FromT>
734 static Optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
735   std::vector<SymbolTable::SymbolUse> uses;
736   auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
737     uses.push_back(symbolUse);
738     return WalkResult::advance();
739   };
740   auto result = walkSymbolUses(from, walkFn);
741   return result ? Optional<SymbolTable::UseRange>(std::move(uses)) : llvm::None;
742 }
743 
744 /// Get an iterator range for all of the uses, for any symbol, that are nested
745 /// within the given operation 'from'. This does not traverse into any nested
746 /// symbol tables, and will also only return uses on 'from' if it does not
747 /// also define a symbol table. This is because we treat the region as the
748 /// boundary of the symbol table, and not the op itself. This function returns
749 /// None if there are any unknown operations that may potentially be symbol
750 /// tables.
751 auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> {
752   return getSymbolUsesImpl(from);
753 }
754 auto SymbolTable::getSymbolUses(Region *from) -> Optional<UseRange> {
755   return getSymbolUsesImpl(MutableArrayRef<Region>(*from));
756 }
757 
758 //===----------------------------------------------------------------------===//
// SymbolTable::getSymbolUses (uses of a specific symbol)
760 
761 /// The implementation of SymbolTable::getSymbolUses below.
762 template <typename SymbolT, typename IRUnitT>
763 static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
764                                                          IRUnitT *limit) {
765   std::vector<SymbolTable::SymbolUse> uses;
766   for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
767     if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
768           if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
769             uses.push_back(symbolUse);
770         }))
771       return llvm::None;
772   }
773   return SymbolTable::UseRange(std::move(uses));
774 }
775 
776 /// Get all of the uses of the given symbol that are nested within the given
777 /// operation 'from', invoking the provided callback for each. This does not
778 /// traverse into any nested symbol tables. This function returns None if there
779 /// are any unknown operations that may potentially be symbol tables.
780 auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
781     -> Optional<UseRange> {
782   return getSymbolUsesImpl(symbol, from);
783 }
784 auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
785     -> Optional<UseRange> {
786   return getSymbolUsesImpl(symbol, from);
787 }
788 auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
789     -> Optional<UseRange> {
790   return getSymbolUsesImpl(symbol, from);
791 }
792 auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
793     -> Optional<UseRange> {
794   return getSymbolUsesImpl(symbol, from);
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // SymbolTable::symbolKnownUseEmpty
799 
800 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
801 template <typename SymbolT, typename IRUnitT>
802 static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
803   for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
804     // Walk all of the symbol uses looking for a reference to 'symbol'.
805     if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
806           return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
807                      ? WalkResult::interrupt()
808                      : WalkResult::advance();
809         }) != WalkResult::advance())
810       return false;
811   }
812   return true;
813 }
814 
815 /// Return if the given symbol is known to have no uses that are nested within
816 /// the given operation 'from'. This does not traverse into any nested symbol
817 /// tables. This function will also return false if there are any unknown
818 /// operations that may potentially be symbol tables.
819 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
820   return symbolKnownUseEmptyImpl(symbol, from);
821 }
822 bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
823   return symbolKnownUseEmptyImpl(symbol, from);
824 }
825 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
826   return symbolKnownUseEmptyImpl(symbol, from);
827 }
828 bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
829   return symbolKnownUseEmptyImpl(symbol, from);
830 }
831 
832 //===----------------------------------------------------------------------===//
833 // SymbolTable::replaceAllSymbolUses
834 
835 /// Generates a new symbol reference attribute with a new leaf reference.
836 static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
837                                         FlatSymbolRefAttr newLeafAttr) {
838   if (oldAttr.isa<FlatSymbolRefAttr>())
839     return newLeafAttr;
840   auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
841   nestedRefs.back() = newLeafAttr;
842   return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
843 }
844 
845 /// The implementation of SymbolTable::replaceAllSymbolUses below.
846 template <typename SymbolT, typename IRUnitT>
847 static LogicalResult
848 replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
849   // Generate a new attribute to replace the given attribute.
850   FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
851   for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
852     SymbolRefAttr oldAttr = scope.symbol;
853     SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
854 
855     auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
856       auto remapAttrFn = [&](Attribute attr) -> Attribute {
857         if (attr == oldAttr)
858           return newAttr;
859         // Handle prefix matches.
860         if (SymbolRefAttr symRef = attr.dyn_cast<SymbolRefAttr>()) {
861           if (isReferencePrefixOf(oldAttr, symRef)) {
862             auto oldNestedRefs = oldAttr.getNestedReferences();
863             auto nestedRefs = symRef.getNestedReferences();
864             if (oldNestedRefs.empty())
865               return SymbolRefAttr::get(newSymbol, nestedRefs);
866 
867             auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
868             newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
869             return SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs);
870           }
871         }
872         return attr;
873       };
874       // Generate a new attribute dictionary by replacing references to the old
875       // symbol.
876       auto newDict = op->getAttrDictionary().replaceSubElements(remapAttrFn);
877       if (!newDict)
878         return WalkResult::interrupt();
879 
880       op->setAttrs(newDict.template cast<DictionaryAttr>());
881       return WalkResult::advance();
882     };
883     if (!scope.walkSymbolTable(walkFn))
884       return failure();
885   }
886   return success();
887 }
888 
889 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
890 /// provided symbol 'newSymbol' that are nested within the given operation
891 /// 'from'. This does not traverse into any nested symbol tables. If there are
892 /// any unknown operations that may potentially be symbol tables, no uses are
893 /// replaced and failure is returned.
894 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
895                                                 StringAttr newSymbol,
896                                                 Operation *from) {
897   return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
898 }
899 LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
900                                                 StringAttr newSymbol,
901                                                 Operation *from) {
902   return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
903 }
904 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
905                                                 StringAttr newSymbol,
906                                                 Region *from) {
907   return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
908 }
909 LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
910                                                 StringAttr newSymbol,
911                                                 Region *from) {
912   return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
913 }
914 
915 //===----------------------------------------------------------------------===//
916 // SymbolTableCollection
917 //===----------------------------------------------------------------------===//
918 
919 Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
920                                                  StringAttr symbol) {
921   return getSymbolTable(symbolTableOp).lookup(symbol);
922 }
923 Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
924                                                  SymbolRefAttr name) {
925   SmallVector<Operation *, 4> symbols;
926   if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
927     return nullptr;
928   return symbols.back();
929 }
930 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
931 /// a given SymbolRefAttr. Returns failure if any of the nested references could
932 /// not be resolved.
933 LogicalResult
934 SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
935                                       SymbolRefAttr name,
936                                       SmallVectorImpl<Operation *> &symbols) {
937   auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
938     return lookupSymbolIn(symbolTableOp, symbol);
939   };
940   return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
941 }
942 
943 /// Returns the operation registered with the given symbol name within the
944 /// closest parent operation of, or including, 'from' with the
945 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
946 /// found.
947 Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
948                                                           StringAttr symbol) {
949   Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
950   return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
951 }
952 Operation *
953 SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
954                                                SymbolRefAttr symbol) {
955   Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
956   return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
957 }
958 
959 /// Lookup, or create, a symbol table for an operation.
960 SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
961   auto it = symbolTables.try_emplace(op, nullptr);
962   if (it.second)
963     it.first->second = std::make_unique<SymbolTable>(op);
964   return *it.first->second;
965 }
966 
967 //===----------------------------------------------------------------------===//
968 // SymbolUserMap
969 //===----------------------------------------------------------------------===//
970 
971 SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
972                              Operation *symbolTableOp)
973     : symbolTable(symbolTable) {
974   // Walk each of the symbol tables looking for discardable callgraph nodes.
975   SmallVector<Operation *> symbols;
976   auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
977     for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
978       auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
979       assert(symbolUses && "expected uses to be valid");
980 
981       for (const SymbolTable::SymbolUse &use : *symbolUses) {
982         symbols.clear();
983         (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
984                                          symbols);
985         for (Operation *symbolOp : symbols)
986           symbolToUsers[symbolOp].insert(use.getUser());
987       }
988     }
989   };
990   // We just set `allSymUsesVisible` to false here because it isn't necessary
991   // for building the user map.
992   SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
993                                 walkFn);
994 }
995 
996 void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
997                                        StringAttr newSymbolName) {
998   auto it = symbolToUsers.find(symbol);
999   if (it == symbolToUsers.end())
1000     return;
1001 
1002   // Replace the uses within the users of `symbol`.
1003   for (Operation *user : it->second)
1004     (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1005 
1006   // Move the current users of `symbol` to the new symbol if it is in the
1007   // symbol table.
1008   Operation *newSymbol =
1009       symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1010   if (newSymbol != symbol) {
1011     // Transfer over the users to the new symbol.  The reference to the old one
1012     // is fetched again as the iterator is invalidated during the insertion.
1013     auto newIt = symbolToUsers.try_emplace(newSymbol, SetVector<Operation *>{});
1014     auto oldIt = symbolToUsers.find(symbol);
1015     assert(oldIt != symbolToUsers.end() && "missing old users list");
1016     if (newIt.second)
1017       newIt.first->second = std::move(oldIt->second);
1018     else
1019       newIt.first->second.set_union(oldIt->second);
1020     symbolToUsers.erase(oldIt);
1021   }
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // Visibility parsing implementation.
1026 //===----------------------------------------------------------------------===//
1027 
1028 ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
1029                                                  NamedAttrList &attrs) {
1030   StringRef visibility;
1031   if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1032     return failure();
1033 
1034   StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1035   attrs.push_back(parser.getBuilder().getNamedAttr(
1036       SymbolTable::getVisibilityAttrName(), visibilityAttr));
1037   return success();
1038 }
1039 
1040 //===----------------------------------------------------------------------===//
1041 // Symbol Interfaces
1042 //===----------------------------------------------------------------------===//
1043 
1044 /// Include the generated symbol interfaces.
1045 #include "mlir/IR/SymbolInterfaces.cpp.inc"
1046