xref: /llvm-project/mlir/test/lib/IR/TestSymbolUses.cpp (revision e95e94adc6bb748de015ac3053e7f0786b65f351)
//===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "TestOps.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/Pass/Pass.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 /// This is a symbol test pass that tests the symbol uselist functionality
17 /// provided by the symbol table along with erasing from the symbol table.
18 struct SymbolUsesPass
19     : public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anond7b0b8eb0111::SymbolUsesPass20   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SymbolUsesPass)
21 
22   StringRef getArgument() const final { return "test-symbol-uses"; }
getDescription__anond7b0b8eb0111::SymbolUsesPass23   StringRef getDescription() const final {
24     return "Test detection of symbol uses";
25   }
operateOnSymbol__anond7b0b8eb0111::SymbolUsesPass26   WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
27                              SmallVectorImpl<func::FuncOp> &deadFunctions) {
28     // Test computing uses on a non symboltable op.
29     std::optional<SymbolTable::UseRange> symbolUses =
30         SymbolTable::getSymbolUses(symbol);
31 
32     // Test the conservative failure case.
33     if (!symbolUses) {
34       symbol->emitRemark()
35           << "symbol contains an unknown nested operation that "
36              "'may' define a new symbol table";
37       return WalkResult::interrupt();
38     }
39     if (unsigned numUses = llvm::size(*symbolUses))
40       symbol->emitRemark() << "symbol contains " << numUses
41                            << " nested references";
42 
43     // Test the functionality of symbolKnownUseEmpty.
44     if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) {
45       func::FuncOp funcSymbol = dyn_cast<func::FuncOp>(symbol);
46       if (funcSymbol && funcSymbol.isExternal())
47         deadFunctions.push_back(funcSymbol);
48 
49       symbol->emitRemark() << "symbol has no uses";
50       return WalkResult::advance();
51     }
52 
53     // Test the functionality of getSymbolUses.
54     symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion());
55     assert(symbolUses && "expected no unknown operations");
56     for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
57       // Check that we can resolve back to our symbol.
58       if (SymbolTable::lookupNearestSymbolFrom(
59               symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) {
60         symbolUse.getUser()->emitRemark()
61             << "found use of symbol : " << symbolUse.getSymbolRef() << " : "
62             << *symbol->getInherentAttr(SymbolTable::getSymbolAttrName());
63       }
64     }
65     symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses";
66     return WalkResult::advance();
67   }
68 
runOnOperation__anond7b0b8eb0111::SymbolUsesPass69   void runOnOperation() override {
70     auto module = getOperation();
71 
72     // Walk nested symbols.
73     SmallVector<func::FuncOp, 4> deadFunctions;
74     module.getBodyRegion().walk([&](Operation *nestedOp) {
75       if (isa<SymbolOpInterface>(nestedOp))
76         return operateOnSymbol(nestedOp, module, deadFunctions);
77       return WalkResult::advance();
78     });
79 
80     SymbolTable table(module);
81     for (Operation *op : deadFunctions) {
82       // In order to test the SymbolTable::erase method, also erase completely
83       // useless functions.
84       auto name = SymbolTable::getSymbolName(op);
85       assert(table.lookup(name) && "expected no unknown operations");
86       table.erase(op);
87       assert(!table.lookup(name) &&
88              "expected erased operation to be unknown now");
89       module.emitRemark() << name.getValue() << " function successfully erased";
90     }
91   }
92 };
93 
94 /// This is a symbol test pass that tests the symbol use replacement
95 /// functionality provided by the symbol table.
96 struct SymbolReplacementPass
97     : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anond7b0b8eb0111::SymbolReplacementPass98   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SymbolReplacementPass)
99 
100   StringRef getArgument() const final { return "test-symbol-rauw"; }
getDescription__anond7b0b8eb0111::SymbolReplacementPass101   StringRef getDescription() const final {
102     return "Test replacement of symbol uses";
103   }
runOnOperation__anond7b0b8eb0111::SymbolReplacementPass104   void runOnOperation() override {
105     ModuleOp module = getOperation();
106 
107     // Don't try to replace if we can't collect symbol uses.
108     if (!SymbolTable::getSymbolUses(&module.getBodyRegion()))
109       return;
110 
111     SymbolTableCollection symbolTable;
112     SymbolUserMap symbolUsers(symbolTable, module);
113     module.getBodyRegion().walk([&](Operation *nestedOp) {
114       StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
115       if (!newName)
116         return;
117       symbolUsers.replaceAllUsesWith(nestedOp, newName);
118       SymbolTable::setSymbolName(nestedOp, newName);
119     });
120   }
121 };
122 } // namespace
123 
124 namespace mlir {
registerSymbolTestPasses()125 void registerSymbolTestPasses() {
126   PassRegistration<SymbolUsesPass>();
127 
128   PassRegistration<SymbolReplacementPass>();
129 }
130 } // namespace mlir
131