1927b7074SRiver Riddle //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===// 2927b7074SRiver Riddle // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6927b7074SRiver Riddle // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 8927b7074SRiver Riddle 9927b7074SRiver Riddle #include "mlir/IR/SymbolTable.h" 108b5a3e46SRahul Joshi #include "mlir/IR/Builders.h" 118b5a3e46SRahul Joshi #include "mlir/IR/OpImplementation.h" 126fca03f0SRiver Riddle #include "llvm/ADT/SetVector.h" 136fca03f0SRiver Riddle #include "llvm/ADT/SmallPtrSet.h" 14ee8e8b55SRiver Riddle #include "llvm/ADT/SmallString.h" 159b92e4fbSRiver Riddle #include "llvm/ADT/StringSwitch.h" 16a1fe1f5fSKazu Hirata #include <optional> 17927b7074SRiver Riddle 18927b7074SRiver Riddle using namespace mlir; 19927b7074SRiver Riddle 20b3a6ae83SRiver Riddle /// Return true if the given operation is unknown and may potentially define a 21b3a6ae83SRiver Riddle /// symbol table. 22b3a6ae83SRiver Riddle static bool isPotentiallyUnknownSymbolTable(Operation *op) { 2373547b08SRiver Riddle return op->getNumRegions() == 1 && !op->getDialect(); 24b3a6ae83SRiver Riddle } 25b3a6ae83SRiver Riddle 2641d4aa7dSChris Lattner /// Returns the string name of the given symbol, or null if this is not a 276fca03f0SRiver Riddle /// symbol. 2841d4aa7dSChris Lattner static StringAttr getNameIfSymbol(Operation *op) { 2941d4aa7dSChris Lattner return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 306fca03f0SRiver Riddle } 31195730a6SRiver Riddle static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) { 3241d4aa7dSChris Lattner return op->getAttrOfType<StringAttr>(symbolAttrNameId); 33eda450bbSRiver Riddle } 346fca03f0SRiver Riddle 356fca03f0SRiver Riddle /// Computes the nested symbol reference attribute for the symbol 'symbolName' 366fca03f0SRiver Riddle /// that are usable within the symbol table operations from 'symbol' as far up 376fca03f0SRiver Riddle /// to the given operation 'within', where 'within' is an ancestor of 'symbol'. 386fca03f0SRiver Riddle /// Returns success if all references up to 'within' could be computed. 396fca03f0SRiver Riddle static LogicalResult 4041d4aa7dSChris Lattner collectValidReferencesFor(Operation *symbol, StringAttr symbolName, 416fca03f0SRiver Riddle Operation *within, 426fca03f0SRiver Riddle SmallVectorImpl<SymbolRefAttr> &results) { 436fca03f0SRiver Riddle assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor"); 446fca03f0SRiver Riddle MLIRContext *ctx = symbol->getContext(); 456fca03f0SRiver Riddle 4641d4aa7dSChris Lattner auto leafRef = FlatSymbolRefAttr::get(symbolName); 476fca03f0SRiver Riddle results.push_back(leafRef); 486fca03f0SRiver Riddle 496fca03f0SRiver Riddle // Early exit for when 'within' is the parent of 'symbol'. 506fca03f0SRiver Riddle Operation *symbolTableOp = symbol->getParentOp(); 516fca03f0SRiver Riddle if (within == symbolTableOp) 526fca03f0SRiver Riddle return success(); 536fca03f0SRiver Riddle 546fca03f0SRiver Riddle // Collect references until 'symbolTableOp' reaches 'within'. 556fca03f0SRiver Riddle SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef); 56195730a6SRiver Riddle StringAttr symbolNameId = 57195730a6SRiver Riddle StringAttr::get(ctx, SymbolTable::getSymbolAttrName()); 586fca03f0SRiver Riddle do { 596fca03f0SRiver Riddle // Each parent of 'symbol' should define a symbol table. 606fca03f0SRiver Riddle if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) 616fca03f0SRiver Riddle return failure(); 626fca03f0SRiver Riddle // Each parent of 'symbol' should also be a symbol. 6341d4aa7dSChris Lattner StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId); 646fca03f0SRiver Riddle if (!symbolTableName) 656fca03f0SRiver Riddle return failure(); 6641d4aa7dSChris Lattner results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs)); 676fca03f0SRiver Riddle 686fca03f0SRiver Riddle symbolTableOp = symbolTableOp->getParentOp(); 696fca03f0SRiver Riddle if (symbolTableOp == within) 706fca03f0SRiver Riddle break; 716fca03f0SRiver Riddle nestedRefs.insert(nestedRefs.begin(), 7241d4aa7dSChris Lattner FlatSymbolRefAttr::get(symbolTableName)); 736fca03f0SRiver Riddle } while (true); 746fca03f0SRiver Riddle return success(); 756fca03f0SRiver Riddle } 766fca03f0SRiver Riddle 7771eeb5ecSRiver Riddle /// Walk all of the operations within the given set of regions, without 7871eeb5ecSRiver Riddle /// traversing into any nested symbol tables. Stops walking if the result of the 7971eeb5ecSRiver Riddle /// callback is anything other than `WalkResult::advance`. 800a81ace0SKazu Hirata static std::optional<WalkResult> 8171eeb5ecSRiver Riddle walkSymbolTable(MutableArrayRef<Region> regions, 820a81ace0SKazu Hirata function_ref<std::optional<WalkResult>(Operation *)> callback) { 8371eeb5ecSRiver Riddle SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions)); 8471eeb5ecSRiver Riddle while (!worklist.empty()) { 8571eeb5ecSRiver Riddle for (Operation &op : worklist.pop_back_val()->getOps()) { 860a81ace0SKazu Hirata std::optional<WalkResult> result = callback(&op); 8771eeb5ecSRiver Riddle if (result != WalkResult::advance()) 8871eeb5ecSRiver Riddle return result; 8971eeb5ecSRiver Riddle 9071eeb5ecSRiver Riddle // If this op defines a new symbol table scope, we can't traverse. Any 9171eeb5ecSRiver Riddle // symbol references nested within 'op' are different semantically. 9271eeb5ecSRiver Riddle if (!op.hasTrait<OpTrait::SymbolTable>()) { 9371eeb5ecSRiver Riddle for (Region ®ion : op.getRegions()) 9471eeb5ecSRiver Riddle worklist.push_back(®ion); 9571eeb5ecSRiver Riddle } 9671eeb5ecSRiver Riddle } 9771eeb5ecSRiver Riddle } 9871eeb5ecSRiver Riddle return WalkResult::advance(); 9971eeb5ecSRiver Riddle } 10071eeb5ecSRiver Riddle 10101eedbc7SRiver Riddle /// Walk all of the operations nested under, and including, the given operation, 10201eedbc7SRiver Riddle /// without traversing into any nested symbol tables. Stops walking if the 10301eedbc7SRiver Riddle /// result of the callback is anything other than `WalkResult::advance`. 1040a81ace0SKazu Hirata static std::optional<WalkResult> 10501eedbc7SRiver Riddle walkSymbolTable(Operation *op, 1060a81ace0SKazu Hirata function_ref<std::optional<WalkResult>(Operation *)> callback) { 1070a81ace0SKazu Hirata std::optional<WalkResult> result = callback(op); 10801eedbc7SRiver Riddle if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>()) 10901eedbc7SRiver Riddle return result; 11001eedbc7SRiver Riddle return walkSymbolTable(op->getRegions(), callback); 11101eedbc7SRiver Riddle } 11201eedbc7SRiver Riddle 113b3a6ae83SRiver Riddle //===----------------------------------------------------------------------===// 114b3a6ae83SRiver Riddle // SymbolTable 115b3a6ae83SRiver Riddle //===----------------------------------------------------------------------===// 116b3a6ae83SRiver Riddle 117ee8e8b55SRiver Riddle /// Build a symbol table with the symbols within the given operation. 118b8cd0c14STres Popp SymbolTable::SymbolTable(Operation *symbolTableOp) 119b8cd0c14STres Popp : symbolTableOp(symbolTableOp) { 120b8cd0c14STres Popp assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() && 121ee8e8b55SRiver Riddle "expected operation to have SymbolTable trait"); 122b8cd0c14STres Popp assert(symbolTableOp->getNumRegions() == 1 && 123ee8e8b55SRiver Riddle "expected operation to have a single region"); 124204c3b55SRiver Riddle assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) && 125b8cd0c14STres Popp "expected operation to have a single block"); 126ee8e8b55SRiver Riddle 127195730a6SRiver Riddle StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(), 128195730a6SRiver Riddle SymbolTable::getSymbolAttrName()); 129b8cd0c14STres Popp for (auto &op : symbolTableOp->getRegion(0).front()) { 13041d4aa7dSChris Lattner StringAttr name = getNameIfSymbol(&op, symbolNameId); 1316fca03f0SRiver Riddle if (!name) 132ee8e8b55SRiver Riddle continue; 133ee8e8b55SRiver Riddle 13441d4aa7dSChris Lattner auto inserted = symbolTable.insert({name, &op}); 1358c47e2edSRiver Riddle (void)inserted; 1368c47e2edSRiver Riddle assert(inserted.second && 137ee8e8b55SRiver Riddle "expected region to contain uniquely named symbol operations"); 138ee8e8b55SRiver Riddle } 1398c47e2edSRiver Riddle } 1408c47e2edSRiver Riddle 141927b7074SRiver Riddle /// Look up a symbol with the specified name, returning null if no such name 142927b7074SRiver Riddle /// exists. Names never include the @ on them. 143ee8e8b55SRiver Riddle Operation *SymbolTable::lookup(StringRef name) const { 14441d4aa7dSChris Lattner return lookup(StringAttr::get(symbolTableOp->getContext(), name)); 14541d4aa7dSChris Lattner } 14641d4aa7dSChris Lattner Operation *SymbolTable::lookup(StringAttr name) const { 147e7d594bbSRiver Riddle return symbolTable.lookup(name); 148927b7074SRiver Riddle } 149927b7074SRiver Riddle 150b6a32d94SRiver Riddle void SymbolTable::remove(Operation *op) { 151b6a32d94SRiver Riddle StringAttr name = getNameIfSymbol(op); 1526fca03f0SRiver Riddle assert(name && "expected valid 'name' attribute"); 153b6a32d94SRiver Riddle assert(op->getParentOp() == symbolTableOp && 154b8cd0c14STres Popp "expected this operation to be inside of the operation with this " 155b8cd0c14STres Popp "SymbolTable"); 156ee8e8b55SRiver Riddle 15741d4aa7dSChris Lattner auto it = symbolTable.find(name); 158b6a32d94SRiver Riddle if (it != symbolTable.end() && it->second == op) 159927b7074SRiver Riddle symbolTable.erase(it); 160b8cd0c14STres Popp } 161b6a32d94SRiver Riddle 162b6a32d94SRiver Riddle void SymbolTable::erase(Operation *symbol) { 163b6a32d94SRiver Riddle remove(symbol); 164b6a32d94SRiver Riddle symbol->erase(); 165927b7074SRiver Riddle } 166927b7074SRiver Riddle 167f43e67ccSTres Popp // TODO: Consider if this should be renamed to something like insertOrUpdate 168f43e67ccSTres Popp /// Insert a new symbol into the table and associated operation if not already 169feec2d90SAlex Zinenko /// there and rename it as necessary to avoid collisions. Return the name of 170feec2d90SAlex Zinenko /// the symbol after insertion as attribute. 171feec2d90SAlex Zinenko StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { 172f43e67ccSTres Popp // The symbol cannot be the child of another op and must be the child of the 173f43e67ccSTres Popp // symbolTableOp after this. 174f43e67ccSTres Popp // 175f43e67ccSTres Popp // TODO: consider if SymbolTable's constructor should behave the same. 176f43e67ccSTres Popp if (!symbol->getParentOp()) { 177b8cd0c14STres Popp auto &body = symbolTableOp->getRegion(0).front(); 178973ddb7dSMehdi Amini if (insertPt == Block::iterator()) { 179973ddb7dSMehdi Amini insertPt = Block::iterator(body.end()); 180973ddb7dSMehdi Amini } else { 181973ddb7dSMehdi Amini assert((insertPt == body.end() || 182973ddb7dSMehdi Amini insertPt->getParentOp() == symbolTableOp) && 183b8cd0c14STres Popp "expected insertPt to be in the associated module operation"); 184973ddb7dSMehdi Amini } 185973ddb7dSMehdi Amini // Insert before the terminator, if any. 186973ddb7dSMehdi Amini if (insertPt == Block::iterator(body.end()) && !body.empty() && 187973ddb7dSMehdi Amini std::prev(body.end())->hasTrait<OpTrait::IsTerminator>()) 188973ddb7dSMehdi Amini insertPt = std::prev(body.end()); 189b8cd0c14STres Popp 190b8cd0c14STres Popp body.getOperations().insert(insertPt, symbol); 191f43e67ccSTres Popp } 192f43e67ccSTres Popp assert(symbol->getParentOp() == symbolTableOp && 193f43e67ccSTres Popp "symbol is already inserted in another op"); 194b8cd0c14STres Popp 195927b7074SRiver Riddle // Add this symbol to the symbol table, uniquing the name if a conflict is 196927b7074SRiver Riddle // detected. 19741d4aa7dSChris Lattner StringAttr name = getSymbolName(symbol); 1986fca03f0SRiver Riddle if (symbolTable.insert({name, symbol}).second) 199feec2d90SAlex Zinenko return name; 200f43e67ccSTres Popp // If the symbol was already in the table, also return. 201f43e67ccSTres Popp if (symbolTable.lookup(name) == symbol) 202feec2d90SAlex Zinenko return name; 203927b7074SRiver Riddle 20441d4aa7dSChris Lattner MLIRContext *context = symbol->getContext(); 205ea84897bSGuray Ozen SmallString<128> nameBuffer = generateSymbolName<128>( 206ea84897bSGuray Ozen name.getValue(), 207ea84897bSGuray Ozen [&](StringRef candidate) { 208ea84897bSGuray Ozen return !symbolTable 209ea84897bSGuray Ozen .insert({StringAttr::get(context, candidate), symbol}) 210ea84897bSGuray Ozen .second; 211ea84897bSGuray Ozen }, 212ea84897bSGuray Ozen uniquingCounter); 2136fca03f0SRiver Riddle setSymbolName(symbol, nameBuffer); 214feec2d90SAlex Zinenko return getSymbolName(symbol); 2156fca03f0SRiver Riddle } 2166fca03f0SRiver Riddle 21778768994SIngo Müller LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) { 21878768994SIngo Müller Operation *op = lookup(from); 21978768994SIngo Müller return rename(op, to); 22078768994SIngo Müller } 22178768994SIngo Müller 22278768994SIngo Müller LogicalResult SymbolTable::rename(Operation *op, StringAttr to) { 22378768994SIngo Müller StringAttr from = getNameIfSymbol(op); 224e1f90b50SChristian Sigg (void)from; 22578768994SIngo Müller 22678768994SIngo Müller assert(from && "expected valid 'name' attribute"); 22778768994SIngo Müller assert(op->getParentOp() == symbolTableOp && 22878768994SIngo Müller "expected this operation to be inside of the operation with this " 22978768994SIngo Müller "SymbolTable"); 23078768994SIngo Müller assert(lookup(from) == op && "current name does not resolve to op"); 23178768994SIngo Müller assert(lookup(to) == nullptr && "new name already exists"); 23278768994SIngo Müller 23378768994SIngo Müller if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp()))) 23478768994SIngo Müller return failure(); 23578768994SIngo Müller 23678768994SIngo Müller // Remove op with old name, change name, add with new name. The order is 23778768994SIngo Müller // important here due to how `remove` and `insert` rely on the op name. 23878768994SIngo Müller remove(op); 23978768994SIngo Müller setSymbolName(op, to); 24078768994SIngo Müller insert(op); 24178768994SIngo Müller 24278768994SIngo Müller assert(lookup(to) == op && "new name does not resolve to renamed op"); 24378768994SIngo Müller assert(lookup(from) == nullptr && "old name still exists"); 24478768994SIngo Müller 24578768994SIngo Müller return success(); 24678768994SIngo Müller } 24778768994SIngo Müller 24878768994SIngo Müller LogicalResult SymbolTable::rename(StringAttr from, StringRef to) { 24978768994SIngo Müller auto toAttr = StringAttr::get(getOp()->getContext(), to); 25078768994SIngo Müller return rename(from, toAttr); 25178768994SIngo Müller } 25278768994SIngo Müller 25378768994SIngo Müller LogicalResult SymbolTable::rename(Operation *op, StringRef to) { 25478768994SIngo Müller auto toAttr = StringAttr::get(getOp()->getContext(), to); 25578768994SIngo Müller return rename(op, toAttr); 25678768994SIngo Müller } 25778768994SIngo Müller 25878768994SIngo Müller FailureOr<StringAttr> 25978768994SIngo Müller SymbolTable::renameToUnique(StringAttr oldName, 26078768994SIngo Müller ArrayRef<SymbolTable *> others) { 26178768994SIngo Müller 26278768994SIngo Müller // Determine new name that is unique in all symbol tables. 26378768994SIngo Müller StringAttr newName; 26478768994SIngo Müller { 26578768994SIngo Müller MLIRContext *context = oldName.getContext(); 26678768994SIngo Müller SmallString<64> prefix = oldName.getValue(); 26778768994SIngo Müller int uniqueId = 0; 26878768994SIngo Müller prefix.push_back('_'); 26978768994SIngo Müller while (true) { 27078768994SIngo Müller newName = StringAttr::get(context, prefix + Twine(uniqueId++)); 27178768994SIngo Müller auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); }; 27278768994SIngo Müller if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) { 27378768994SIngo Müller break; 27478768994SIngo Müller } 27578768994SIngo Müller } 27678768994SIngo Müller } 27778768994SIngo Müller 27878768994SIngo Müller // Apply renaming. 27978768994SIngo Müller if (failed(rename(oldName, newName))) 28078768994SIngo Müller return failure(); 28178768994SIngo Müller return newName; 28278768994SIngo Müller } 28378768994SIngo Müller 28478768994SIngo Müller FailureOr<StringAttr> 28578768994SIngo Müller SymbolTable::renameToUnique(Operation *op, ArrayRef<SymbolTable *> others) { 28678768994SIngo Müller StringAttr from = getNameIfSymbol(op); 28778768994SIngo Müller assert(from && "expected valid 'name' attribute"); 28878768994SIngo Müller return renameToUnique(from, others); 28978768994SIngo Müller } 29078768994SIngo Müller 2916fca03f0SRiver Riddle /// Returns the name of the given symbol operation. 29241d4aa7dSChris Lattner StringAttr SymbolTable::getSymbolName(Operation *symbol) { 29341d4aa7dSChris Lattner StringAttr name = getNameIfSymbol(symbol); 2946fca03f0SRiver Riddle assert(name && "expected valid symbol name"); 29541d4aa7dSChris Lattner return name; 2966fca03f0SRiver Riddle } 29741d4aa7dSChris Lattner 2986fca03f0SRiver Riddle /// Sets the name of the given symbol operation. 29941d4aa7dSChris Lattner void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) { 30041d4aa7dSChris Lattner symbol->setAttr(getSymbolAttrName(), name); 301ee8e8b55SRiver Riddle } 302ee8e8b55SRiver Riddle 3039b92e4fbSRiver Riddle /// Returns the visibility of the given symbol operation. 3049b92e4fbSRiver Riddle SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) { 3059b92e4fbSRiver Riddle // If the attribute doesn't exist, assume public. 3069b92e4fbSRiver Riddle StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName()); 3079b92e4fbSRiver Riddle if (!vis) 3089b92e4fbSRiver Riddle return Visibility::Public; 3099b92e4fbSRiver Riddle 3109b92e4fbSRiver Riddle // Otherwise, switch on the string value. 311cc83dc19SChristian Sigg return StringSwitch<Visibility>(vis.getValue()) 3129b92e4fbSRiver Riddle .Case("private", Visibility::Private) 3139b92e4fbSRiver Riddle .Case("nested", Visibility::Nested) 3149b92e4fbSRiver Riddle .Case("public", Visibility::Public); 3159b92e4fbSRiver Riddle } 3169b92e4fbSRiver Riddle /// Sets the visibility of the given symbol operation. 3179b92e4fbSRiver Riddle void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) { 3189b92e4fbSRiver Riddle MLIRContext *ctx = symbol->getContext(); 3199b92e4fbSRiver Riddle 3209b92e4fbSRiver Riddle // If the visibility is public, just drop the attribute as this is the 3219b92e4fbSRiver Riddle // default. 3229b92e4fbSRiver Riddle if (vis == Visibility::Public) { 323195730a6SRiver Riddle symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName())); 3249b92e4fbSRiver Riddle return; 3259b92e4fbSRiver Riddle } 3269b92e4fbSRiver Riddle 3279b92e4fbSRiver Riddle // Otherwise, update the attribute. 3289b92e4fbSRiver Riddle assert((vis == Visibility::Private || vis == Visibility::Nested) && 3299b92e4fbSRiver Riddle "unknown symbol visibility kind"); 3309b92e4fbSRiver Riddle 3319b92e4fbSRiver Riddle StringRef visName = vis == Visibility::Private ? "private" : "nested"; 332c2c83e97STres Popp symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName)); 3339b92e4fbSRiver Riddle } 3349b92e4fbSRiver Riddle 33529e411b3SLei Zhang /// Returns the nearest symbol table from a given operation `from`. Returns 33629e411b3SLei Zhang /// nullptr if no valid parent symbol table could be found. 33729e411b3SLei Zhang Operation *SymbolTable::getNearestSymbolTable(Operation *from) { 33829e411b3SLei Zhang assert(from && "expected valid operation"); 33929e411b3SLei Zhang if (isPotentiallyUnknownSymbolTable(from)) 34029e411b3SLei Zhang return nullptr; 34129e411b3SLei Zhang 34229e411b3SLei Zhang while (!from->hasTrait<OpTrait::SymbolTable>()) { 34329e411b3SLei Zhang from = from->getParentOp(); 34429e411b3SLei Zhang 34529e411b3SLei Zhang // Check that this is a valid op and isn't an unknown symbol table. 34629e411b3SLei Zhang if (!from || isPotentiallyUnknownSymbolTable(from)) 34729e411b3SLei Zhang return nullptr; 34829e411b3SLei Zhang } 34929e411b3SLei Zhang return from; 35029e411b3SLei Zhang } 35129e411b3SLei Zhang 352a90151d6SRiver Riddle /// Walks all symbol table operations nested within, and including, `op`. For 353a90151d6SRiver Riddle /// each symbol table operation, the provided callback is invoked with the op 354a90151d6SRiver Riddle /// and a boolean signifying if the symbols within that symbol table can be 355a90151d6SRiver Riddle /// treated as if all uses are visible. `allSymUsesVisible` identifies whether 356a90151d6SRiver Riddle /// all of the symbol uses of symbols within `op` are visible. 357a90151d6SRiver Riddle void SymbolTable::walkSymbolTables( 358a90151d6SRiver Riddle Operation *op, bool allSymUsesVisible, 359a90151d6SRiver Riddle function_ref<void(Operation *, bool)> callback) { 360a90151d6SRiver Riddle bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>(); 361a90151d6SRiver Riddle if (isSymbolTable) { 362a90151d6SRiver Riddle SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op); 363a90151d6SRiver Riddle allSymUsesVisible |= !symbol || symbol.isPrivate(); 364a90151d6SRiver Riddle } else { 365a90151d6SRiver Riddle // Otherwise if 'op' is not a symbol table, any nested symbols are 366a90151d6SRiver Riddle // guaranteed to be hidden. 367a90151d6SRiver Riddle allSymUsesVisible = true; 368a90151d6SRiver Riddle } 369a90151d6SRiver Riddle 370a90151d6SRiver Riddle for (Region ®ion : op->getRegions()) 371a90151d6SRiver Riddle for (Block &block : region) 372a90151d6SRiver Riddle for (Operation &nestedOp : block) 373a90151d6SRiver Riddle walkSymbolTables(&nestedOp, allSymUsesVisible, callback); 374a90151d6SRiver Riddle 375a90151d6SRiver Riddle // If 'op' had the symbol table trait, visit it after any nested symbol 376a90151d6SRiver Riddle // tables. 377a90151d6SRiver Riddle if (isSymbolTable) 378a90151d6SRiver Riddle callback(op, allSymUsesVisible); 379a90151d6SRiver Riddle } 380a90151d6SRiver Riddle 3818cb405a8SRiver Riddle /// Returns the operation registered with the given symbol name with the 3828cb405a8SRiver Riddle /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation 3838cb405a8SRiver Riddle /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol 3848cb405a8SRiver Riddle /// was found. 3858cb405a8SRiver Riddle Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, 38641d4aa7dSChris Lattner StringAttr symbol) { 3878cb405a8SRiver Riddle assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); 388973ddb7dSMehdi Amini Region ®ion = symbolTableOp->getRegion(0); 389973ddb7dSMehdi Amini if (region.empty()) 390973ddb7dSMehdi Amini return nullptr; 3918cb405a8SRiver Riddle 3928cb405a8SRiver Riddle // Look for a symbol with the given name. 393195730a6SRiver Riddle StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(), 394195730a6SRiver Riddle SymbolTable::getSymbolAttrName()); 395973ddb7dSMehdi Amini for (auto &op : region.front()) 396eda450bbSRiver Riddle if (getNameIfSymbol(&op, symbolNameId) == symbol) 3978cb405a8SRiver Riddle return &op; 3988cb405a8SRiver Riddle return nullptr; 3998cb405a8SRiver Riddle } 4006fca03f0SRiver Riddle Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, 4016fca03f0SRiver Riddle SymbolRefAttr symbol) { 402b276dec5SRiver Riddle SmallVector<Operation *, 4> resolvedSymbols; 403b276dec5SRiver Riddle if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols))) 404b276dec5SRiver Riddle return nullptr; 405b276dec5SRiver Riddle return resolvedSymbols.back(); 406b276dec5SRiver Riddle } 407b276dec5SRiver Riddle 4087bc7d0acSRiver Riddle /// Internal implementation of `lookupSymbolIn` that allows for specialized 4097bc7d0acSRiver Riddle /// implementations of the lookup function. 4107bc7d0acSRiver Riddle static LogicalResult lookupSymbolInImpl( 4117bc7d0acSRiver Riddle Operation *symbolTableOp, SymbolRefAttr symbol, 4127bc7d0acSRiver Riddle SmallVectorImpl<Operation *> &symbols, 41341d4aa7dSChris Lattner function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) { 4146fca03f0SRiver Riddle assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); 4156fca03f0SRiver Riddle 4166fca03f0SRiver Riddle // Lookup the root reference for this symbol. 4177bc7d0acSRiver Riddle symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference()); 4186fca03f0SRiver Riddle if (!symbolTableOp) 419b276dec5SRiver Riddle return failure(); 420b276dec5SRiver Riddle symbols.push_back(symbolTableOp); 4216fca03f0SRiver Riddle 4226fca03f0SRiver Riddle // If there are no nested references, just return the root symbol directly. 4236fca03f0SRiver Riddle ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences(); 4246fca03f0SRiver Riddle if (nestedRefs.empty()) 425b276dec5SRiver Riddle return success(); 4266fca03f0SRiver Riddle 4276fca03f0SRiver Riddle // Verify that the root is also a symbol table. 4286fca03f0SRiver Riddle if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) 429b276dec5SRiver Riddle return failure(); 4306fca03f0SRiver Riddle 4316fca03f0SRiver Riddle // Otherwise, lookup each of the nested non-leaf references and ensure that 4326fca03f0SRiver Riddle // each corresponds to a valid symbol table. 4336fca03f0SRiver Riddle for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) { 43441d4aa7dSChris Lattner symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr()); 4356fca03f0SRiver Riddle if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>()) 436b276dec5SRiver Riddle return failure(); 437b276dec5SRiver Riddle symbols.push_back(symbolTableOp); 4386fca03f0SRiver Riddle } 4397bc7d0acSRiver Riddle symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference())); 440b276dec5SRiver Riddle return success(symbols.back()); 4416fca03f0SRiver Riddle } 4428cb405a8SRiver Riddle 4437bc7d0acSRiver Riddle LogicalResult 4447bc7d0acSRiver Riddle SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol, 4457bc7d0acSRiver Riddle SmallVectorImpl<Operation *> &symbols) { 44641d4aa7dSChris Lattner auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) { 4477bc7d0acSRiver Riddle return lookupSymbolIn(symbolTableOp, symbol); 4487bc7d0acSRiver Riddle }; 4497bc7d0acSRiver Riddle return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn); 4507bc7d0acSRiver Riddle } 4517bc7d0acSRiver Riddle 4528cb405a8SRiver Riddle /// Returns the operation registered with the given symbol name within the 4538cb405a8SRiver Riddle /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns 4548cb405a8SRiver Riddle /// nullptr if no valid symbol was found. 4558cb405a8SRiver Riddle Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, 45641d4aa7dSChris Lattner StringAttr symbol) { 4576fca03f0SRiver Riddle Operation *symbolTableOp = getNearestSymbolTable(from); 4586fca03f0SRiver Riddle return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; 459b3a6ae83SRiver Riddle } 4606fca03f0SRiver Riddle Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, 4616fca03f0SRiver Riddle SymbolRefAttr symbol) { 4626fca03f0SRiver Riddle Operation *symbolTableOp = getNearestSymbolTable(from); 4636fca03f0SRiver Riddle return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; 4648cb405a8SRiver Riddle } 4658cb405a8SRiver Riddle 466f492c359SRiver Riddle raw_ostream &mlir::operator<<(raw_ostream &os, 467f492c359SRiver Riddle SymbolTable::Visibility visibility) { 468f492c359SRiver Riddle switch (visibility) { 469f492c359SRiver Riddle case SymbolTable::Visibility::Public: 470f492c359SRiver Riddle return os << "public"; 471f492c359SRiver Riddle case SymbolTable::Visibility::Private: 472f492c359SRiver Riddle return os << "private"; 473f492c359SRiver Riddle case SymbolTable::Visibility::Nested: 474f492c359SRiver Riddle return os << "nested"; 475f492c359SRiver Riddle } 4762c811548SMehdi Amini llvm_unreachable("Unexpected visibility"); 477f492c359SRiver Riddle } 478f492c359SRiver Riddle 479ee8e8b55SRiver Riddle //===----------------------------------------------------------------------===// 480ee8e8b55SRiver Riddle // SymbolTable Trait Types 481ee8e8b55SRiver Riddle //===----------------------------------------------------------------------===// 482ee8e8b55SRiver Riddle 4837c221a7dSRiver Riddle LogicalResult detail::verifySymbolTable(Operation *op) { 484ee8e8b55SRiver Riddle if (op->getNumRegions() != 1) 485ee8e8b55SRiver Riddle return op->emitOpError() 486ee8e8b55SRiver Riddle << "Operations with a 'SymbolTable' must have exactly one region"; 487204c3b55SRiver Riddle if (!llvm::hasSingleElement(op->getRegion(0))) 488b8cd0c14STres Popp return op->emitOpError() 489b8cd0c14STres Popp << "Operations with a 'SymbolTable' must have exactly one block"; 490ee8e8b55SRiver Riddle 4918bfedb3cSKazuaki Ishizaki // Check that all symbols are uniquely named within child regions. 4926fca03f0SRiver Riddle DenseMap<Attribute, Location> nameToOrigLoc; 493ee8e8b55SRiver Riddle for (auto &block : op->getRegion(0)) { 494ee8e8b55SRiver Riddle for (auto &op : block) { 495ee8e8b55SRiver Riddle // Check for a symbol name attribute. 496ee8e8b55SRiver Riddle auto nameAttr = 497ee8e8b55SRiver Riddle op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()); 498ee8e8b55SRiver Riddle if (!nameAttr) 499ee8e8b55SRiver Riddle continue; 500ee8e8b55SRiver Riddle 501ee8e8b55SRiver Riddle // Try to insert this symbol into the table. 5026fca03f0SRiver Riddle auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc()); 503ee8e8b55SRiver Riddle if (!it.second) 504ee8e8b55SRiver Riddle return op.emitError() 505ee8e8b55SRiver Riddle .append("redefinition of symbol named '", nameAttr.getValue(), "'") 506ee8e8b55SRiver Riddle .attachNote(it.first->second) 507ee8e8b55SRiver Riddle .append("see existing symbol definition here"); 508ee8e8b55SRiver Riddle } 509ee8e8b55SRiver Riddle } 51071eeb5ecSRiver Riddle 51171eeb5ecSRiver Riddle // Verify any nested symbol user operations. 51271eeb5ecSRiver Riddle SymbolTableCollection symbolTable; 5130a81ace0SKazu Hirata auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> { 51471eeb5ecSRiver Riddle if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op)) 51571eeb5ecSRiver Riddle return WalkResult(user.verifySymbolUses(symbolTable)); 51671eeb5ecSRiver Riddle return WalkResult::advance(); 51771eeb5ecSRiver Riddle }; 51871eeb5ecSRiver Riddle 5190a81ace0SKazu Hirata std::optional<WalkResult> result = 52071eeb5ecSRiver Riddle walkSymbolTable(op->getRegions(), verifySymbolUserFn); 52171eeb5ecSRiver Riddle return success(result && !result->wasInterrupted()); 522927b7074SRiver Riddle } 523ac91e673SRiver Riddle 5247c221a7dSRiver Riddle LogicalResult detail::verifySymbol(Operation *op) { 5259b92e4fbSRiver Riddle // Verify the name attribute. 5269ac459e8SRiver Riddle if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName())) 5279ac459e8SRiver Riddle return op->emitOpError() << "requires string attribute '" 5289ac459e8SRiver Riddle << mlir::SymbolTable::getSymbolAttrName() << "'"; 5299b92e4fbSRiver Riddle 5309b92e4fbSRiver Riddle // Verify the visibility attribute. 5319b92e4fbSRiver Riddle if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) { 532c1fa60b4STres Popp StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis); 5339b92e4fbSRiver Riddle if (!visStrAttr) 5349b92e4fbSRiver Riddle return op->emitOpError() << "requires visibility attribute '" 5359b92e4fbSRiver Riddle << mlir::SymbolTable::getVisibilityAttrName() 5369b92e4fbSRiver Riddle << "' to be a string attribute, but got " << vis; 5379b92e4fbSRiver Riddle 5389b92e4fbSRiver Riddle if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"}, 5399b92e4fbSRiver Riddle visStrAttr.getValue())) 5409b92e4fbSRiver Riddle return op->emitOpError() 5419b92e4fbSRiver Riddle << "visibility expected to be one of [\"public\", \"private\", " 5429b92e4fbSRiver Riddle "\"nested\"], but got " 5439b92e4fbSRiver Riddle << visStrAttr; 5449b92e4fbSRiver Riddle } 5459ac459e8SRiver Riddle return success(); 5469ac459e8SRiver Riddle } 5479ac459e8SRiver Riddle 548ac91e673SRiver Riddle //===----------------------------------------------------------------------===// 549ef43b565SRiver Riddle // Symbol Use Lists 550ac91e673SRiver Riddle //===----------------------------------------------------------------------===// 551ac91e673SRiver Riddle 552ac91e673SRiver Riddle /// Walk all of the symbol references within the given operation, invoking the 55301eedbc7SRiver Riddle /// provided callback for each found use. The callbacks takes the use of the 55401eedbc7SRiver Riddle /// symbol. 55501eedbc7SRiver Riddle static WalkResult 55601eedbc7SRiver Riddle walkSymbolRefs(Operation *op, 55701eedbc7SRiver Riddle function_ref<WalkResult(SymbolTable::SymbolUse)> callback) { 55803d136cfSRiver Riddle return op->getAttrDictionary().walk<WalkOrder::PreOrder>( 55903d136cfSRiver Riddle [&](SymbolRefAttr symbolRef) { 56001eedbc7SRiver Riddle if (callback({op, symbolRef}).wasInterrupted()) 56101eedbc7SRiver Riddle return WalkResult::interrupt(); 56201eedbc7SRiver Riddle 56303d136cfSRiver Riddle // Don't walk nested references. 56403d136cfSRiver Riddle return WalkResult::skip(); 56503d136cfSRiver Riddle }); 566ac91e673SRiver Riddle } 567ac91e673SRiver Riddle 568ac91e673SRiver Riddle /// Walk all of the uses, for any symbol, that are nested within the given 569ab9e5598SRiver Riddle /// regions, invoking the provided callback for each. This does not traverse 570ab9e5598SRiver Riddle /// into any nested symbol tables. 5710a81ace0SKazu Hirata static std::optional<WalkResult> 57201eedbc7SRiver Riddle walkSymbolUses(MutableArrayRef<Region> regions, 57301eedbc7SRiver Riddle function_ref<WalkResult(SymbolTable::SymbolUse)> callback) { 5740a81ace0SKazu Hirata return walkSymbolTable(regions, 5750a81ace0SKazu Hirata [&](Operation *op) -> std::optional<WalkResult> { 5760a81ace0SKazu Hirata // Check that this isn't a potentially unknown symbol 5770a81ace0SKazu Hirata // table. 57871eeb5ecSRiver Riddle if (isPotentiallyUnknownSymbolTable(op)) 5791a36588eSKazu Hirata return std::nullopt; 580b3a6ae83SRiver Riddle 58171eeb5ecSRiver Riddle return walkSymbolRefs(op, callback); 58271eeb5ecSRiver Riddle }); 583ac91e673SRiver Riddle } 584ab9e5598SRiver Riddle /// Walk all of the uses, for any symbol, that are nested within the given 5855aacce3dSKazuaki Ishizaki /// operation 'from', invoking the provided callback for each. This does not 586ab9e5598SRiver Riddle /// traverse into any nested symbol tables. 5870a81ace0SKazu Hirata static std::optional<WalkResult> 58801eedbc7SRiver Riddle walkSymbolUses(Operation *from, 58901eedbc7SRiver Riddle function_ref<WalkResult(SymbolTable::SymbolUse)> callback) { 590ab9e5598SRiver Riddle // If this operation has regions, and it, as well as its dialect, isn't 591ab9e5598SRiver Riddle // registered then conservatively fail. The operation may define a 592ab9e5598SRiver Riddle // symbol table, so we can't opaquely know if we should traverse to find 593ab9e5598SRiver Riddle // nested uses. 594ab9e5598SRiver Riddle if (isPotentiallyUnknownSymbolTable(from)) 5951a36588eSKazu Hirata return std::nullopt; 596ac91e673SRiver Riddle 597ab9e5598SRiver Riddle // Walk the uses on this operation. 598ab9e5598SRiver Riddle if (walkSymbolRefs(from, callback).wasInterrupted()) 599ab9e5598SRiver Riddle return WalkResult::interrupt(); 600ab9e5598SRiver Riddle 601ab9e5598SRiver Riddle // Only recurse if this operation is not a symbol table. A symbol table 602ab9e5598SRiver Riddle // defines a new scope, so we can't walk the attributes from within the symbol 603ab9e5598SRiver Riddle // table op. 604ab9e5598SRiver Riddle if (!from->hasTrait<OpTrait::SymbolTable>()) 605ab9e5598SRiver Riddle return walkSymbolUses(from->getRegions(), callback); 606ab9e5598SRiver Riddle return WalkResult::advance(); 607ab9e5598SRiver Riddle } 608ab9e5598SRiver Riddle 609ab9e5598SRiver Riddle namespace { 610ab9e5598SRiver Riddle /// This class represents a single symbol scope. A symbol scope represents the 611ab9e5598SRiver Riddle /// set of operations nested within a symbol table that may reference symbols 612ab9e5598SRiver Riddle /// within that table. A symbol scope does not contain the symbol table 613ab9e5598SRiver Riddle /// operation itself, just its contained operations. A scope ends at leaf 614ab9e5598SRiver Riddle /// operations or another symbol table operation. 615ab9e5598SRiver Riddle struct SymbolScope { 616ab9e5598SRiver Riddle /// Walk the symbol uses within this scope, invoking the given callback. 617ab9e5598SRiver Riddle /// This variant is used when the callback type matches that expected by 618ab9e5598SRiver Riddle /// 'walkSymbolUses'. 619ab9e5598SRiver Riddle template <typename CallbackT, 620de49627dSKazu Hirata std::enable_if_t<!std::is_same< 6218cbe371cSRiver Riddle typename llvm::function_traits<CallbackT>::result_t, 6228cbe371cSRiver Riddle void>::value> * = nullptr> 6230a81ace0SKazu Hirata std::optional<WalkResult> walk(CallbackT cback) { 62468f58812STres Popp if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) 625ab9e5598SRiver Riddle return walkSymbolUses(*region, cback); 626*fcb1591bSKazu Hirata return walkSymbolUses(cast<Operation *>(limit), cback); 627ab9e5598SRiver Riddle } 628ab9e5598SRiver Riddle /// This variant is used when the callback type matches a stripped down type: 629ab9e5598SRiver Riddle /// void(SymbolTable::SymbolUse use) 630ab9e5598SRiver Riddle template <typename CallbackT, 631de49627dSKazu Hirata std::enable_if_t<std::is_same< 6328cbe371cSRiver Riddle typename llvm::function_traits<CallbackT>::result_t, 6338cbe371cSRiver Riddle void>::value> * = nullptr> 6340a81ace0SKazu Hirata std::optional<WalkResult> walk(CallbackT cback) { 63501eedbc7SRiver Riddle return walk([=](SymbolTable::SymbolUse use) { 636ab9e5598SRiver Riddle return cback(use), WalkResult::advance(); 637ab9e5598SRiver Riddle }); 638ab9e5598SRiver Riddle } 639ab9e5598SRiver Riddle 64001eedbc7SRiver Riddle /// Walk all of the operations nested under the current scope without 64101eedbc7SRiver Riddle /// traversing into any nested symbol tables. 64201eedbc7SRiver Riddle template <typename CallbackT> 6430a81ace0SKazu Hirata std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) { 64468f58812STres Popp if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) 64501eedbc7SRiver Riddle return ::walkSymbolTable(*region, cback); 646*fcb1591bSKazu Hirata return ::walkSymbolTable(cast<Operation *>(limit), cback); 64701eedbc7SRiver Riddle } 64801eedbc7SRiver Riddle 649ab9e5598SRiver Riddle /// The representation of the symbol within this scope. 650ab9e5598SRiver Riddle SymbolRefAttr symbol; 651ab9e5598SRiver Riddle 652ab9e5598SRiver Riddle /// The IR unit representing this scope. 653ab9e5598SRiver Riddle llvm::PointerUnion<Operation *, Region *> limit; 654ab9e5598SRiver Riddle }; 655be0a7e9fSMehdi Amini } // namespace 656ab9e5598SRiver Riddle 657ab9e5598SRiver Riddle /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'. 658ab9e5598SRiver Riddle static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, 659ab9e5598SRiver Riddle Operation *limit) { 66041d4aa7dSChris Lattner StringAttr symName = SymbolTable::getSymbolName(symbol); 6616fca03f0SRiver Riddle assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit); 6626fca03f0SRiver Riddle 6636fca03f0SRiver Riddle // Compute the ancestors of 'limit'. 6644efb7754SRiver Riddle SetVector<Operation *, SmallVector<Operation *, 4>, 6656fca03f0SRiver Riddle SmallPtrSet<Operation *, 4>> 6666fca03f0SRiver Riddle limitAncestors; 6676fca03f0SRiver Riddle Operation *limitAncestor = limit; 6686fca03f0SRiver Riddle do { 6696fca03f0SRiver Riddle // Check to see if 'symbol' is an ancestor of 'limit'. 6706fca03f0SRiver Riddle if (limitAncestor == symbol) { 6716fca03f0SRiver Riddle // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr 6726fca03f0SRiver Riddle // doesn't support parent references. 673ab9e5598SRiver Riddle if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == 674ab9e5598SRiver Riddle symbol->getParentOp()) 67541d4aa7dSChris Lattner return {{SymbolRefAttr::get(symName), limit}}; 676ab9e5598SRiver Riddle return {}; 6776fca03f0SRiver Riddle } 6786fca03f0SRiver Riddle 6796fca03f0SRiver Riddle limitAncestors.insert(limitAncestor); 6806fca03f0SRiver Riddle } while ((limitAncestor = limitAncestor->getParentOp())); 6816fca03f0SRiver Riddle 6826fca03f0SRiver Riddle // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'. 6836fca03f0SRiver Riddle Operation *commonAncestor = symbol->getParentOp(); 6846fca03f0SRiver Riddle do { 6856fca03f0SRiver Riddle if (limitAncestors.count(commonAncestor)) 6866fca03f0SRiver Riddle break; 6876fca03f0SRiver Riddle } while ((commonAncestor = commonAncestor->getParentOp())); 6886fca03f0SRiver Riddle assert(commonAncestor && "'limit' and 'symbol' have no common ancestor"); 6896fca03f0SRiver Riddle 6906fca03f0SRiver Riddle // Compute the set of valid nested references for 'symbol' as far up to the 6916fca03f0SRiver Riddle // common ancestor as possible. 6926fca03f0SRiver Riddle SmallVector<SymbolRefAttr, 2> references; 693ab9e5598SRiver Riddle bool collectedAllReferences = succeeded( 694ab9e5598SRiver Riddle collectValidReferencesFor(symbol, symName, commonAncestor, references)); 6956fca03f0SRiver Riddle 6966fca03f0SRiver Riddle // Handle the case where the common ancestor is 'limit'. 6976fca03f0SRiver Riddle if (commonAncestor == limit) { 698ab9e5598SRiver Riddle SmallVector<SymbolScope, 2> scopes; 699ab9e5598SRiver Riddle 7006fca03f0SRiver Riddle // Walk each of the ancestors of 'symbol', calling the compute function for 7016fca03f0SRiver Riddle // each one. 7026fca03f0SRiver Riddle Operation *limitIt = symbol->getParentOp(); 7036fca03f0SRiver Riddle for (size_t i = 0, e = references.size(); i != e; 7046fca03f0SRiver Riddle ++i, limitIt = limitIt->getParentOp()) { 705ab9e5598SRiver Riddle assert(limitIt->hasTrait<OpTrait::SymbolTable>()); 706ab9e5598SRiver Riddle scopes.push_back({references[i], &limitIt->getRegion(0)}); 7076fca03f0SRiver Riddle } 708ab9e5598SRiver Riddle return scopes; 7096fca03f0SRiver Riddle } 7106fca03f0SRiver Riddle 7116fca03f0SRiver Riddle // Otherwise, we just need the symbol reference for 'symbol' that will be 7126fca03f0SRiver Riddle // used within 'limit'. This is the last reference in the list we computed 7136fca03f0SRiver Riddle // above if we were able to collect all references. 7146fca03f0SRiver Riddle if (!collectedAllReferences) 715ab9e5598SRiver Riddle return {}; 716ab9e5598SRiver Riddle return {{references.back(), limit}}; 7176fca03f0SRiver Riddle } 718ab9e5598SRiver Riddle static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, 719ab9e5598SRiver Riddle Region *limit) { 720ab9e5598SRiver Riddle auto scopes = collectSymbolScopes(symbol, limit->getParentOp()); 7216fca03f0SRiver Riddle 722ab9e5598SRiver Riddle // If we collected some scopes to walk, make sure to constrain the one for 723ab9e5598SRiver Riddle // limit to the specific region requested. 724ab9e5598SRiver Riddle if (!scopes.empty()) 725ab9e5598SRiver Riddle scopes.back().limit = limit; 726ab9e5598SRiver Riddle return scopes; 727ab9e5598SRiver Riddle } 72841d4aa7dSChris Lattner static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, 72947905788SIngo Müller Region *limit) { 73041d4aa7dSChris Lattner return {{SymbolRefAttr::get(symbol), limit}}; 7316fca03f0SRiver Riddle } 7326fca03f0SRiver Riddle 73347905788SIngo Müller static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, 73447905788SIngo Müller Operation *limit) { 73547905788SIngo Müller SmallVector<SymbolScope, 1> scopes; 73647905788SIngo Müller auto symbolRef = SymbolRefAttr::get(symbol); 73747905788SIngo Müller for (auto ®ion : limit->getRegions()) 73847905788SIngo Müller scopes.push_back({symbolRef, ®ion}); 73947905788SIngo Müller return scopes; 74047905788SIngo Müller } 74147905788SIngo Müller 7426fca03f0SRiver Riddle /// Returns true if the given reference 'SubRef' is a sub reference of the 7436fca03f0SRiver Riddle /// reference 'ref', i.e. 'ref' is a further qualified reference. 7446fca03f0SRiver Riddle static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) { 7456fca03f0SRiver Riddle if (ref == subRef) 7466fca03f0SRiver Riddle return true; 7476fca03f0SRiver Riddle 7486fca03f0SRiver Riddle // If the references are not pointer equal, check to see if `subRef` is a 7496fca03f0SRiver Riddle // prefix of `ref`. 750c1fa60b4STres Popp if (llvm::isa<FlatSymbolRefAttr>(ref) || 7516fca03f0SRiver Riddle ref.getRootReference() != subRef.getRootReference()) 7526fca03f0SRiver Riddle return false; 7536fca03f0SRiver Riddle 7546fca03f0SRiver Riddle auto refLeafs = ref.getNestedReferences(); 7556fca03f0SRiver Riddle auto subRefLeafs = subRef.getNestedReferences(); 7566fca03f0SRiver Riddle return subRefLeafs.size() < refLeafs.size() && 7576fca03f0SRiver Riddle subRefLeafs == refLeafs.take_front(subRefLeafs.size()); 7586fca03f0SRiver Riddle } 7596fca03f0SRiver Riddle 7606fca03f0SRiver Riddle //===----------------------------------------------------------------------===// 7616fca03f0SRiver Riddle // SymbolTable::getSymbolUses 7626fca03f0SRiver Riddle 763ab9e5598SRiver Riddle /// The implementation of SymbolTable::getSymbolUses below. 764ab9e5598SRiver Riddle template <typename FromT> 765e8bcc37fSRamkumar Ramachandra static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) { 766ab9e5598SRiver Riddle std::vector<SymbolTable::SymbolUse> uses; 76701eedbc7SRiver Riddle auto walkFn = [&](SymbolTable::SymbolUse symbolUse) { 768ab9e5598SRiver Riddle uses.push_back(symbolUse); 769ab9e5598SRiver Riddle return WalkResult::advance(); 770ab9e5598SRiver Riddle }; 771ab9e5598SRiver Riddle auto result = walkSymbolUses(from, walkFn); 772e8bcc37fSRamkumar Ramachandra return result ? std::optional<SymbolTable::UseRange>(std::move(uses)) 7731a36588eSKazu Hirata : std::nullopt; 774ab9e5598SRiver Riddle } 775ab9e5598SRiver Riddle 776b3a6ae83SRiver Riddle /// Get an iterator range for all of the uses, for any symbol, that are nested 777b3a6ae83SRiver Riddle /// within the given operation 'from'. This does not traverse into any nested 778b3a6ae83SRiver Riddle /// symbol tables, and will also only return uses on 'from' if it does not 779ef43b565SRiver Riddle /// also define a symbol table. This is because we treat the region as the 780ef43b565SRiver Riddle /// boundary of the symbol table, and not the op itself. This function returns 7814f81805aSKazu Hirata /// std::nullopt if there are any unknown operations that may potentially be 7824f81805aSKazu Hirata /// symbol tables. 783e8bcc37fSRamkumar Ramachandra auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> { 784ab9e5598SRiver Riddle return getSymbolUsesImpl(from); 785ab9e5598SRiver Riddle } 786e8bcc37fSRamkumar Ramachandra auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> { 787ab9e5598SRiver Riddle return getSymbolUsesImpl(MutableArrayRef<Region>(*from)); 7886fca03f0SRiver Riddle } 7896fca03f0SRiver Riddle 7906fca03f0SRiver Riddle //===----------------------------------------------------------------------===// 7916fca03f0SRiver Riddle // SymbolTable::getSymbolUses 7926fca03f0SRiver Riddle 7936fca03f0SRiver Riddle /// The implementation of SymbolTable::getSymbolUses below. 794ab9e5598SRiver Riddle template <typename SymbolT, typename IRUnitT> 795e8bcc37fSRamkumar Ramachandra static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol, 796ab9e5598SRiver Riddle IRUnitT *limit) { 7976fca03f0SRiver Riddle std::vector<SymbolTable::SymbolUse> uses; 798ab9e5598SRiver Riddle for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { 799ab9e5598SRiver Riddle if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) { 800ab9e5598SRiver Riddle if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())) 801b3a6ae83SRiver Riddle uses.push_back(symbolUse); 802ab9e5598SRiver Riddle })) 8031a36588eSKazu Hirata return std::nullopt; 804ac91e673SRiver Riddle } 805ab9e5598SRiver Riddle return SymbolTable::UseRange(std::move(uses)); 806ab9e5598SRiver Riddle } 807ac91e673SRiver Riddle 808b3a6ae83SRiver Riddle /// Get all of the uses of the given symbol that are nested within the given 809b3a6ae83SRiver Riddle /// operation 'from', invoking the provided callback for each. This does not 81070c73d1bSKazu Hirata /// traverse into any nested symbol tables. This function returns std::nullopt 81170c73d1bSKazu Hirata /// if there are any unknown operations that may potentially be symbol tables. 81241d4aa7dSChris Lattner auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from) 813e8bcc37fSRamkumar Ramachandra -> std::optional<UseRange> { 8146fca03f0SRiver Riddle return getSymbolUsesImpl(symbol, from); 8156fca03f0SRiver Riddle } 8166fca03f0SRiver Riddle auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from) 817e8bcc37fSRamkumar Ramachandra -> std::optional<UseRange> { 8186fca03f0SRiver Riddle return getSymbolUsesImpl(symbol, from); 8196fca03f0SRiver Riddle } 82041d4aa7dSChris Lattner auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from) 821e8bcc37fSRamkumar Ramachandra -> std::optional<UseRange> { 822ab9e5598SRiver Riddle return getSymbolUsesImpl(symbol, from); 823ab9e5598SRiver Riddle } 824ab9e5598SRiver Riddle auto SymbolTable::getSymbolUses(Operation *symbol, Region *from) 825e8bcc37fSRamkumar Ramachandra -> std::optional<UseRange> { 826ab9e5598SRiver Riddle return getSymbolUsesImpl(symbol, from); 827ab9e5598SRiver Riddle } 828b3a6ae83SRiver Riddle 8296fca03f0SRiver Riddle //===----------------------------------------------------------------------===// 8306fca03f0SRiver Riddle // SymbolTable::symbolKnownUseEmpty 8316fca03f0SRiver Riddle 8326fca03f0SRiver Riddle /// The implementation of SymbolTable::symbolKnownUseEmpty below. 833ab9e5598SRiver Riddle template <typename SymbolT, typename IRUnitT> 834ab9e5598SRiver Riddle static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) { 835ab9e5598SRiver Riddle for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { 8366fca03f0SRiver Riddle // Walk all of the symbol uses looking for a reference to 'symbol'. 83701eedbc7SRiver Riddle if (scope.walk([&](SymbolTable::SymbolUse symbolUse) { 838ab9e5598SRiver Riddle return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()) 8396fca03f0SRiver Riddle ? WalkResult::interrupt() 8406fca03f0SRiver Riddle : WalkResult::advance(); 841ab9e5598SRiver Riddle }) != WalkResult::advance()) 842ab9e5598SRiver Riddle return false; 843ab9e5598SRiver Riddle } 844ab9e5598SRiver Riddle return true; 845b3a6ae83SRiver Riddle } 846b3a6ae83SRiver Riddle 847b3a6ae83SRiver Riddle /// Return if the given symbol is known to have no uses that are nested within 848b3a6ae83SRiver Riddle /// the given operation 'from'. This does not traverse into any nested symbol 849ab9e5598SRiver Riddle /// tables. This function will also return false if there are any unknown 850ab9e5598SRiver Riddle /// operations that may potentially be symbol tables. 85141d4aa7dSChris Lattner bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) { 8526fca03f0SRiver Riddle return symbolKnownUseEmptyImpl(symbol, from); 8536fca03f0SRiver Riddle } 8546fca03f0SRiver Riddle bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) { 8556fca03f0SRiver Riddle return symbolKnownUseEmptyImpl(symbol, from); 856ac91e673SRiver Riddle } 85741d4aa7dSChris Lattner bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) { 858ab9e5598SRiver Riddle return symbolKnownUseEmptyImpl(symbol, from); 859ab9e5598SRiver Riddle } 860ab9e5598SRiver Riddle bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) { 861ab9e5598SRiver Riddle return symbolKnownUseEmptyImpl(symbol, from); 862ab9e5598SRiver Riddle } 863ef43b565SRiver Riddle 8646fca03f0SRiver Riddle //===----------------------------------------------------------------------===// 8656fca03f0SRiver Riddle // SymbolTable::replaceAllSymbolUses 8666fca03f0SRiver Riddle 8676fca03f0SRiver Riddle /// Generates a new symbol reference attribute with a new leaf reference. 868df186507SBenjamin Kramer static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, 8696fca03f0SRiver Riddle FlatSymbolRefAttr newLeafAttr) { 870c1fa60b4STres Popp if (llvm::isa<FlatSymbolRefAttr>(oldAttr)) 8716fca03f0SRiver Riddle return newLeafAttr; 8726fca03f0SRiver Riddle auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); 8736fca03f0SRiver Riddle nestedRefs.back() = newLeafAttr; 87441d4aa7dSChris Lattner return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs); 8756fca03f0SRiver Riddle } 8766fca03f0SRiver Riddle 8776fca03f0SRiver Riddle /// The implementation of SymbolTable::replaceAllSymbolUses below. 878ab9e5598SRiver Riddle template <typename SymbolT, typename IRUnitT> 879ab9e5598SRiver Riddle static LogicalResult 88041d4aa7dSChris Lattner replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) { 8816fca03f0SRiver Riddle // Generate a new attribute to replace the given attribute. 88241d4aa7dSChris Lattner FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol); 883ab9e5598SRiver Riddle for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { 88401eedbc7SRiver Riddle SymbolRefAttr oldAttr = scope.symbol; 885ab9e5598SRiver Riddle SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); 886e50941b8SRiver Riddle AttrTypeReplacer replacer; 887e50941b8SRiver Riddle replacer.addReplacement( 888e50941b8SRiver Riddle [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> { 88900a52c75SRiver Riddle // Regardless of the match, don't walk nested SymbolRefAttrs, we don't 89000a52c75SRiver Riddle // want to accidentally replace an inner reference. 89101eedbc7SRiver Riddle if (attr == oldAttr) 89200a52c75SRiver Riddle return {newAttr, WalkResult::skip()}; 89301eedbc7SRiver Riddle // Handle prefix matches. 894e50941b8SRiver Riddle if (isReferencePrefixOf(oldAttr, attr)) { 89501eedbc7SRiver Riddle auto oldNestedRefs = oldAttr.getNestedReferences(); 896e50941b8SRiver Riddle auto nestedRefs = attr.getNestedReferences(); 89701eedbc7SRiver Riddle if (oldNestedRefs.empty()) 89800a52c75SRiver Riddle return {SymbolRefAttr::get(newSymbol, nestedRefs), 89900a52c75SRiver Riddle WalkResult::skip()}; 90001eedbc7SRiver Riddle 90101eedbc7SRiver Riddle auto newNestedRefs = llvm::to_vector<4>(nestedRefs); 90201eedbc7SRiver Riddle newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr; 903e50941b8SRiver Riddle return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs), 90400a52c75SRiver Riddle WalkResult::skip()}; 9056fca03f0SRiver Riddle } 90600a52c75SRiver Riddle return {attr, WalkResult::skip()}; 907e50941b8SRiver Riddle }); 9086fca03f0SRiver Riddle 9090a81ace0SKazu Hirata auto walkFn = [&](Operation *op) -> std::optional<WalkResult> { 910e50941b8SRiver Riddle replacer.replaceElementsIn(op); 9116fca03f0SRiver Riddle return WalkResult::advance(); 9126fca03f0SRiver Riddle }; 91301eedbc7SRiver Riddle if (!scope.walkSymbolTable(walkFn)) 914ab9e5598SRiver Riddle return failure(); 9156fca03f0SRiver Riddle } 9166fca03f0SRiver Riddle return success(); 9176fca03f0SRiver Riddle } 9186fca03f0SRiver Riddle 919ef43b565SRiver Riddle /// Attempt to replace all uses of the given symbol 'oldSymbol' with the 920ef43b565SRiver Riddle /// provided symbol 'newSymbol' that are nested within the given operation 921ab9e5598SRiver Riddle /// 'from'. This does not traverse into any nested symbol tables. If there are 922ab9e5598SRiver Riddle /// any unknown operations that may potentially be symbol tables, no uses are 923ab9e5598SRiver Riddle /// replaced and failure is returned. 92441d4aa7dSChris Lattner LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, 92541d4aa7dSChris Lattner StringAttr newSymbol, 926ef43b565SRiver Riddle Operation *from) { 9276fca03f0SRiver Riddle return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); 928ef43b565SRiver Riddle } 9296fca03f0SRiver Riddle LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, 93041d4aa7dSChris Lattner StringAttr newSymbol, 9316fca03f0SRiver Riddle Operation *from) { 9326fca03f0SRiver Riddle return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); 933ef43b565SRiver Riddle } 93441d4aa7dSChris Lattner LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, 93541d4aa7dSChris Lattner StringAttr newSymbol, 936ab9e5598SRiver Riddle Region *from) { 937ab9e5598SRiver Riddle return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); 938ab9e5598SRiver Riddle } 939ab9e5598SRiver Riddle LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, 94041d4aa7dSChris Lattner StringAttr newSymbol, 941ab9e5598SRiver Riddle Region *from) { 942ab9e5598SRiver Riddle return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); 943ab9e5598SRiver Riddle } 9447c221a7dSRiver Riddle 9457c221a7dSRiver Riddle //===----------------------------------------------------------------------===// 9467bc7d0acSRiver Riddle // SymbolTableCollection 9477bc7d0acSRiver Riddle //===----------------------------------------------------------------------===// 9487bc7d0acSRiver Riddle 9497bc7d0acSRiver Riddle Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, 95041d4aa7dSChris Lattner StringAttr symbol) { 9517bc7d0acSRiver Riddle return getSymbolTable(symbolTableOp).lookup(symbol); 9527bc7d0acSRiver Riddle } 9537bc7d0acSRiver Riddle Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, 9547bc7d0acSRiver Riddle SymbolRefAttr name) { 9557bc7d0acSRiver Riddle SmallVector<Operation *, 4> symbols; 9567bc7d0acSRiver Riddle if (failed(lookupSymbolIn(symbolTableOp, name, symbols))) 9577bc7d0acSRiver Riddle return nullptr; 9587bc7d0acSRiver Riddle return symbols.back(); 9597bc7d0acSRiver Riddle } 9607bc7d0acSRiver Riddle /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by 9617bc7d0acSRiver Riddle /// a given SymbolRefAttr. Returns failure if any of the nested references could 9627bc7d0acSRiver Riddle /// not be resolved. 9637bc7d0acSRiver Riddle LogicalResult 9647bc7d0acSRiver Riddle SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, 9657bc7d0acSRiver Riddle SymbolRefAttr name, 9667bc7d0acSRiver Riddle SmallVectorImpl<Operation *> &symbols) { 96741d4aa7dSChris Lattner auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) { 9687bc7d0acSRiver Riddle return lookupSymbolIn(symbolTableOp, symbol); 9697bc7d0acSRiver Riddle }; 9707bc7d0acSRiver Riddle return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn); 9717bc7d0acSRiver Riddle } 9727bc7d0acSRiver Riddle 97371eeb5ecSRiver Riddle /// Returns the operation registered with the given symbol name within the 97471eeb5ecSRiver Riddle /// closest parent operation of, or including, 'from' with the 97571eeb5ecSRiver Riddle /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 97671eeb5ecSRiver Riddle /// found. 97771eeb5ecSRiver Riddle Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from, 97841d4aa7dSChris Lattner StringAttr symbol) { 97971eeb5ecSRiver Riddle Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from); 98071eeb5ecSRiver Riddle return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; 98171eeb5ecSRiver Riddle } 98271eeb5ecSRiver Riddle Operation * 98371eeb5ecSRiver Riddle SymbolTableCollection::lookupNearestSymbolFrom(Operation *from, 98471eeb5ecSRiver Riddle SymbolRefAttr symbol) { 98571eeb5ecSRiver Riddle Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from); 98671eeb5ecSRiver Riddle return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; 98771eeb5ecSRiver Riddle } 98871eeb5ecSRiver Riddle 9897bc7d0acSRiver Riddle /// Lookup, or create, a symbol table for an operation. 9907bc7d0acSRiver Riddle SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) { 9917bc7d0acSRiver Riddle auto it = symbolTables.try_emplace(op, nullptr); 9927bc7d0acSRiver Riddle if (it.second) 9937bc7d0acSRiver Riddle it.first->second = std::make_unique<SymbolTable>(op); 9947bc7d0acSRiver Riddle return *it.first->second; 9957bc7d0acSRiver Riddle } 9967bc7d0acSRiver Riddle 9977bc7d0acSRiver Riddle //===----------------------------------------------------------------------===// 9982cfc66a6SJeff Niu // LockedSymbolTableCollection 9992cfc66a6SJeff Niu //===----------------------------------------------------------------------===// 10002cfc66a6SJeff Niu 10012cfc66a6SJeff Niu Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, 10022cfc66a6SJeff Niu StringAttr symbol) { 10032cfc66a6SJeff Niu return getSymbolTable(symbolTableOp).lookup(symbol); 10042cfc66a6SJeff Niu } 10052cfc66a6SJeff Niu 10062cfc66a6SJeff Niu Operation * 10072cfc66a6SJeff Niu LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, 10082cfc66a6SJeff Niu FlatSymbolRefAttr symbol) { 10092cfc66a6SJeff Niu return lookupSymbolIn(symbolTableOp, symbol.getAttr()); 10102cfc66a6SJeff Niu } 10112cfc66a6SJeff Niu 10122cfc66a6SJeff Niu Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, 10132cfc66a6SJeff Niu SymbolRefAttr name) { 10142cfc66a6SJeff Niu SmallVector<Operation *> symbols; 10152cfc66a6SJeff Niu if (failed(lookupSymbolIn(symbolTableOp, name, symbols))) 10162cfc66a6SJeff Niu return nullptr; 10172cfc66a6SJeff Niu return symbols.back(); 10182cfc66a6SJeff Niu } 10192cfc66a6SJeff Niu 10202cfc66a6SJeff Niu LogicalResult LockedSymbolTableCollection::lookupSymbolIn( 10212cfc66a6SJeff Niu Operation *symbolTableOp, SymbolRefAttr name, 10222cfc66a6SJeff Niu SmallVectorImpl<Operation *> &symbols) { 10232cfc66a6SJeff Niu auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) { 10242cfc66a6SJeff Niu return lookupSymbolIn(symbolTableOp, symbol); 10252cfc66a6SJeff Niu }; 10262cfc66a6SJeff Niu return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn); 10272cfc66a6SJeff Niu } 10282cfc66a6SJeff Niu 10292cfc66a6SJeff Niu SymbolTable & 103055bc18a7SJeff Niu LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) { 10312cfc66a6SJeff Niu assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); 10322cfc66a6SJeff Niu // Try to find an existing symbol table. 10332cfc66a6SJeff Niu { 10342cfc66a6SJeff Niu llvm::sys::SmartScopedReader<true> lock(mutex); 10352cfc66a6SJeff Niu auto it = collection.symbolTables.find(symbolTableOp); 10362cfc66a6SJeff Niu if (it != collection.symbolTables.end()) 10372cfc66a6SJeff Niu return *it->second; 10382cfc66a6SJeff Niu } 10392cfc66a6SJeff Niu // Create a symbol table for the operation. Perform construction outside of 10402cfc66a6SJeff Niu // the critical section. 10412cfc66a6SJeff Niu auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp); 10422cfc66a6SJeff Niu // Insert the constructed symbol table. 10432cfc66a6SJeff Niu llvm::sys::SmartScopedWriter<true> lock(mutex); 10442cfc66a6SJeff Niu return *collection.symbolTables 10452cfc66a6SJeff Niu .insert({symbolTableOp, std::move(symbolTable)}) 10462cfc66a6SJeff Niu .first->second; 10472cfc66a6SJeff Niu } 10482cfc66a6SJeff Niu 10492cfc66a6SJeff Niu //===----------------------------------------------------------------------===// 10504a7aed4eSRiver Riddle // SymbolUserMap 10514a7aed4eSRiver Riddle //===----------------------------------------------------------------------===// 10524a7aed4eSRiver Riddle 10534a7aed4eSRiver Riddle SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable, 10544a7aed4eSRiver Riddle Operation *symbolTableOp) 10554a7aed4eSRiver Riddle : symbolTable(symbolTable) { 10564a7aed4eSRiver Riddle // Walk each of the symbol tables looking for discardable callgraph nodes. 10574a7aed4eSRiver Riddle SmallVector<Operation *> symbols; 10584a7aed4eSRiver Riddle auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) { 10594a7aed4eSRiver Riddle for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) { 10604a7aed4eSRiver Riddle auto symbolUses = SymbolTable::getSymbolUses(&nestedOp); 10614a7aed4eSRiver Riddle assert(symbolUses && "expected uses to be valid"); 10624a7aed4eSRiver Riddle 10634a7aed4eSRiver Riddle for (const SymbolTable::SymbolUse &use : *symbolUses) { 10644a7aed4eSRiver Riddle symbols.clear(); 10654a7aed4eSRiver Riddle (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(), 10664a7aed4eSRiver Riddle symbols); 10674a7aed4eSRiver Riddle for (Operation *symbolOp : symbols) 10684a7aed4eSRiver Riddle symbolToUsers[symbolOp].insert(use.getUser()); 10694a7aed4eSRiver Riddle } 10704a7aed4eSRiver Riddle } 10714a7aed4eSRiver Riddle }; 10724a7aed4eSRiver Riddle // We just set `allSymUsesVisible` to false here because it isn't necessary 10734a7aed4eSRiver Riddle // for building the user map. 10744a7aed4eSRiver Riddle SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false, 10754a7aed4eSRiver Riddle walkFn); 10764a7aed4eSRiver Riddle } 10774a7aed4eSRiver Riddle 10784a7aed4eSRiver Riddle void SymbolUserMap::replaceAllUsesWith(Operation *symbol, 107941d4aa7dSChris Lattner StringAttr newSymbolName) { 10804a7aed4eSRiver Riddle auto it = symbolToUsers.find(symbol); 10814a7aed4eSRiver Riddle if (it == symbolToUsers.end()) 10824a7aed4eSRiver Riddle return; 10834a7aed4eSRiver Riddle 10844a7aed4eSRiver Riddle // Replace the uses within the users of `symbol`. 1085f92d319cSNandor Licker for (Operation *user : it->second) 10864a7aed4eSRiver Riddle (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user); 10874a7aed4eSRiver Riddle 10884a7aed4eSRiver Riddle // Move the current users of `symbol` to the new symbol if it is in the 10894a7aed4eSRiver Riddle // symbol table. 10904a7aed4eSRiver Riddle Operation *newSymbol = 10914a7aed4eSRiver Riddle symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName); 10924a7aed4eSRiver Riddle if (newSymbol != symbol) { 1093f92d319cSNandor Licker // Transfer over the users to the new symbol. The reference to the old one 1094f92d319cSNandor Licker // is fetched again as the iterator is invalidated during the insertion. 1095f92d319cSNandor Licker auto newIt = symbolToUsers.try_emplace(newSymbol, SetVector<Operation *>{}); 1096f92d319cSNandor Licker auto oldIt = symbolToUsers.find(symbol); 1097f92d319cSNandor Licker assert(oldIt != symbolToUsers.end() && "missing old users list"); 1098f92d319cSNandor Licker if (newIt.second) 1099f92d319cSNandor Licker newIt.first->second = std::move(oldIt->second); 11004a7aed4eSRiver Riddle else 1101f92d319cSNandor Licker newIt.first->second.set_union(oldIt->second); 1102f92d319cSNandor Licker symbolToUsers.erase(oldIt); 11034a7aed4eSRiver Riddle } 11044a7aed4eSRiver Riddle } 11054a7aed4eSRiver Riddle 11064a7aed4eSRiver Riddle //===----------------------------------------------------------------------===// 11078b5a3e46SRahul Joshi // Visibility parsing implementation. 11088b5a3e46SRahul Joshi //===----------------------------------------------------------------------===// 11098b5a3e46SRahul Joshi 11108b5a3e46SRahul Joshi ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser, 11118b5a3e46SRahul Joshi NamedAttrList &attrs) { 11128b5a3e46SRahul Joshi StringRef visibility; 11138b5a3e46SRahul Joshi if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"})) 11148b5a3e46SRahul Joshi return failure(); 11158b5a3e46SRahul Joshi 11168b5a3e46SRahul Joshi StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility); 11178b5a3e46SRahul Joshi attrs.push_back(parser.getBuilder().getNamedAttr( 11188b5a3e46SRahul Joshi SymbolTable::getVisibilityAttrName(), visibilityAttr)); 11198b5a3e46SRahul Joshi return success(); 11208b5a3e46SRahul Joshi } 11218b5a3e46SRahul Joshi 11228b5a3e46SRahul Joshi //===----------------------------------------------------------------------===// 11237c221a7dSRiver Riddle // Symbol Interfaces 11247c221a7dSRiver Riddle //===----------------------------------------------------------------------===// 11257c221a7dSRiver Riddle 11267c221a7dSRiver Riddle /// Include the generated symbol interfaces. 11277c221a7dSRiver Riddle #include "mlir/IR/SymbolInterfaces.cpp.inc" 1128