xref: /llvm-project/mlir/lib/Transforms/SymbolDCE.cpp (revision e8bcc37fff5bda7dd9326903a2c31e6703b4fe68)
1 //===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm for eliminating symbol operations that are
10 // known to be dead.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Transforms/Passes.h"
15 
16 #include "mlir/IR/SymbolTable.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_SYMBOLDCE
20 #include "mlir/Transforms/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 namespace {
26 struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
27   void runOnOperation() override;
28 
29   /// Compute the liveness of the symbols within the given symbol table.
30   /// `symbolTableIsHidden` is true if this symbol table is known to be
31   /// unaccessible from operations in its parent regions.
32   LogicalResult computeLiveness(Operation *symbolTableOp,
33                                 SymbolTableCollection &symbolTable,
34                                 bool symbolTableIsHidden,
35                                 DenseSet<Operation *> &liveSymbols);
36 };
37 } // namespace
38 
runOnOperation()39 void SymbolDCE::runOnOperation() {
40   Operation *symbolTableOp = getOperation();
41 
42   // SymbolDCE should only be run on operations that define a symbol table.
43   if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
44     symbolTableOp->emitOpError()
45         << " was scheduled to run under SymbolDCE, but does not define a "
46            "symbol table";
47     return signalPassFailure();
48   }
49 
50   // A flag that signals if the top level symbol table is hidden, i.e. not
51   // accessible from parent scopes.
52   bool symbolTableIsHidden = true;
53   SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
54   if (symbolTableOp->getParentOp() && symbol)
55     symbolTableIsHidden = symbol.isPrivate();
56 
57   // Compute the set of live symbols within the symbol table.
58   DenseSet<Operation *> liveSymbols;
59   SymbolTableCollection symbolTable;
60   if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
61                              liveSymbols)))
62     return signalPassFailure();
63 
64   // After computing the liveness, delete all of the symbols that were found to
65   // be dead.
66   symbolTableOp->walk([&](Operation *nestedSymbolTable) {
67     if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
68       return;
69     for (auto &block : nestedSymbolTable->getRegion(0)) {
70       for (Operation &op : llvm::make_early_inc_range(block)) {
71         if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
72           op.erase();
73           ++numDCE;
74         }
75       }
76     }
77   });
78 }
79 
80 /// Compute the liveness of the symbols within the given symbol table.
81 /// `symbolTableIsHidden` is true if this symbol table is known to be
82 /// unaccessible from operations in its parent regions.
computeLiveness(Operation * symbolTableOp,SymbolTableCollection & symbolTable,bool symbolTableIsHidden,DenseSet<Operation * > & liveSymbols)83 LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
84                                          SymbolTableCollection &symbolTable,
85                                          bool symbolTableIsHidden,
86                                          DenseSet<Operation *> &liveSymbols) {
87   // A worklist of live operations to propagate uses from.
88   SmallVector<Operation *, 16> worklist;
89 
90   // Walk the symbols within the current symbol table, marking the symbols that
91   // are known to be live.
92   for (auto &block : symbolTableOp->getRegion(0)) {
93     // Add all non-symbols or symbols that can't be discarded.
94     for (Operation &op : block) {
95       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
96       if (!symbol) {
97         worklist.push_back(&op);
98         continue;
99       }
100       bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
101                            symbol.canDiscardOnUseEmpty();
102       if (!isDiscardable && liveSymbols.insert(&op).second)
103         worklist.push_back(&op);
104     }
105   }
106 
107   // Process the set of symbols that were known to be live, adding new symbols
108   // that are referenced within.
109   while (!worklist.empty()) {
110     Operation *op = worklist.pop_back_val();
111 
112     // If this is a symbol table, recursively compute its liveness.
113     if (op->hasTrait<OpTrait::SymbolTable>()) {
114       // The internal symbol table is hidden if the parent is, if its not a
115       // symbol, or if it is a private symbol.
116       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
117       bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
118       if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
119         return failure();
120     }
121 
122     // Collect the uses held by this operation.
123     std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
124     if (!uses) {
125       return op->emitError()
126              << "operation contains potentially unknown symbol table, "
127                 "meaning that we can't reliable compute symbol uses";
128     }
129 
130     SmallVector<Operation *, 4> resolvedSymbols;
131     for (const SymbolTable::SymbolUse &use : *uses) {
132       // Lookup the symbols referenced by this use.
133       resolvedSymbols.clear();
134       if (failed(symbolTable.lookupSymbolIn(
135               op->getParentOp(), use.getSymbolRef(), resolvedSymbols)))
136         // Ignore references to unknown symbols.
137         continue;
138 
139       // Mark each of the resolved symbols as live.
140       for (Operation *resolvedSymbol : resolvedSymbols)
141         if (liveSymbols.insert(resolvedSymbol).second)
142           worklist.push_back(resolvedSymbol);
143     }
144   }
145 
146   return success();
147 }
148 
createSymbolDCEPass()149 std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
150   return std::make_unique<SymbolDCE>();
151 }
152