xref: /llvm-project/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp (revision 13317502da8ee3885854f67700140586c0edafee)
1 //===- DeadCodeAnalysis.cpp - Dead code analysis --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
11 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12 #include "mlir/Analysis/DataFlowFramework.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/Block.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Location.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "mlir/IR/Value.h"
20 #include "mlir/IR/ValueRange.h"
21 #include "mlir/Interfaces/CallInterfaces.h"
22 #include "mlir/Interfaces/ControlFlowInterfaces.h"
23 #include "mlir/Support/LLVM.h"
24 #include "llvm/Support/Casting.h"
25 #include <cassert>
26 #include <optional>
27 
28 using namespace mlir;
29 using namespace mlir::dataflow;
30 
31 //===----------------------------------------------------------------------===//
32 // Executable
33 //===----------------------------------------------------------------------===//
34 
35 ChangeResult Executable::setToLive() {
36   if (live)
37     return ChangeResult::NoChange;
38   live = true;
39   return ChangeResult::Change;
40 }
41 
42 void Executable::print(raw_ostream &os) const {
43   os << (live ? "live" : "dead");
44 }
45 
46 void Executable::onUpdate(DataFlowSolver *solver) const {
47   AnalysisState::onUpdate(solver);
48 
49   if (ProgramPoint *pp = llvm::dyn_cast_if_present<ProgramPoint *>(anchor)) {
50     if (pp->isBlockStart()) {
51       // Re-invoke the analyses on the block itself.
52       for (DataFlowAnalysis *analysis : subscribers)
53         solver->enqueue({pp, analysis});
54       // Re-invoke the analyses on all operations in the block.
55       for (DataFlowAnalysis *analysis : subscribers)
56         for (Operation &op : *pp->getBlock())
57           solver->enqueue({solver->getProgramPointAfter(&op), analysis});
58     }
59   } else if (auto *latticeAnchor =
60                  llvm::dyn_cast_if_present<GenericLatticeAnchor *>(anchor)) {
61     // Re-invoke the analysis on the successor block.
62     if (auto *edge = dyn_cast<CFGEdge>(latticeAnchor)) {
63       for (DataFlowAnalysis *analysis : subscribers)
64         solver->enqueue(
65             {solver->getProgramPointBefore(edge->getTo()), analysis});
66     }
67   }
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // PredecessorState
72 //===----------------------------------------------------------------------===//
73 
74 void PredecessorState::print(raw_ostream &os) const {
75   if (allPredecessorsKnown())
76     os << "(all) ";
77   os << "predecessors:\n";
78   for (Operation *op : getKnownPredecessors())
79     os << "  " << *op << "\n";
80 }
81 
82 ChangeResult PredecessorState::join(Operation *predecessor) {
83   return knownPredecessors.insert(predecessor) ? ChangeResult::Change
84                                                : ChangeResult::NoChange;
85 }
86 
87 ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) {
88   ChangeResult result = join(predecessor);
89   if (!inputs.empty()) {
90     ValueRange &curInputs = successorInputs[predecessor];
91     if (curInputs != inputs) {
92       curInputs = inputs;
93       result |= ChangeResult::Change;
94     }
95   }
96   return result;
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // CFGEdge
101 //===----------------------------------------------------------------------===//
102 
103 Location CFGEdge::getLoc() const {
104   return FusedLoc::get(
105       getFrom()->getParent()->getContext(),
106       {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()});
107 }
108 
109 void CFGEdge::print(raw_ostream &os) const {
110   getFrom()->print(os);
111   os << "\n -> \n";
112   getTo()->print(os);
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // DeadCodeAnalysis
117 //===----------------------------------------------------------------------===//
118 
119 DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
120     : DataFlowAnalysis(solver) {
121   registerAnchorKind<CFGEdge>();
122 }
123 
124 LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
125   // Mark the top-level blocks as executable.
126   for (Region &region : top->getRegions()) {
127     if (region.empty())
128       continue;
129     auto *state =
130         getOrCreate<Executable>(getProgramPointBefore(&region.front()));
131     propagateIfChanged(state, state->setToLive());
132   }
133 
134   // Mark as overdefined the predecessors of symbol callables with potentially
135   // unknown predecessors.
136   initializeSymbolCallables(top);
137 
138   return initializeRecursively(top);
139 }
140 
141 void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
142   analysisScope = top;
143   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
144     Region &symbolTableRegion = symTable->getRegion(0);
145     Block *symbolTableBlock = &symbolTableRegion.front();
146 
147     bool foundSymbolCallable = false;
148     for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
149       Region *callableRegion = callable.getCallableRegion();
150       if (!callableRegion)
151         continue;
152       auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
153       if (!symbol)
154         continue;
155 
156       // Public symbol callables or those for which we can't see all uses have
157       // potentially unknown callsites.
158       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
159         auto *state =
160             getOrCreate<PredecessorState>(getProgramPointAfter(callable));
161         propagateIfChanged(state, state->setHasUnknownPredecessors());
162       }
163       foundSymbolCallable = true;
164     }
165 
166     // Exit early if no eligible symbol callables were found in the table.
167     if (!foundSymbolCallable)
168       return;
169 
170     // Walk the symbol table to check for non-call uses of symbols.
171     std::optional<SymbolTable::UseRange> uses =
172         SymbolTable::getSymbolUses(&symbolTableRegion);
173     if (!uses) {
174       // If we couldn't gather the symbol uses, conservatively assume that
175       // we can't track information for any nested symbols.
176       return top->walk([&](CallableOpInterface callable) {
177         auto *state =
178             getOrCreate<PredecessorState>(getProgramPointAfter(callable));
179         propagateIfChanged(state, state->setHasUnknownPredecessors());
180       });
181     }
182 
183     for (const SymbolTable::SymbolUse &use : *uses) {
184       if (isa<CallOpInterface>(use.getUser()))
185         continue;
186       // If a callable symbol has a non-call use, then we can't be guaranteed to
187       // know all callsites.
188       Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
189       if (!symbol)
190         continue;
191       auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol));
192       propagateIfChanged(state, state->setHasUnknownPredecessors());
193     }
194   };
195   SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
196                                 walkFn);
197 }
198 
199 /// Returns true if the operation is a returning terminator in region
200 /// control-flow or the terminator of a callable region.
201 static bool isRegionOrCallableReturn(Operation *op) {
202   return op->getBlock() != nullptr && !op->getNumSuccessors() &&
203          isa<RegionBranchOpInterface, CallableOpInterface>(op->getParentOp()) &&
204          op->getBlock()->getTerminator() == op;
205 }
206 
207 LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
208   // Initialize the analysis by visiting every op with control-flow semantics.
209   if (op->getNumRegions() || op->getNumSuccessors() ||
210       isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
211     // When the liveness of the parent block changes, make sure to re-invoke the
212     // analysis on the op.
213     if (op->getBlock())
214       getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
215           ->blockContentSubscribe(this);
216     // Visit the op.
217     if (failed(visit(getProgramPointAfter(op))))
218       return failure();
219   }
220   // Recurse on nested operations.
221   for (Region &region : op->getRegions())
222     for (Operation &op : region.getOps())
223       if (failed(initializeRecursively(&op)))
224         return failure();
225   return success();
226 }
227 
228 void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
229   auto *state = getOrCreate<Executable>(getProgramPointBefore(to));
230   propagateIfChanged(state, state->setToLive());
231   auto *edgeState =
232       getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(from, to));
233   propagateIfChanged(edgeState, edgeState->setToLive());
234 }
235 
236 void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
237   for (Region &region : op->getRegions()) {
238     if (region.empty())
239       continue;
240     auto *state =
241         getOrCreate<Executable>(getProgramPointBefore(&region.front()));
242     propagateIfChanged(state, state->setToLive());
243   }
244 }
245 
246 LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
247   if (point->isBlockStart())
248     return success();
249   Operation *op = point->getPrevOp();
250 
251   // If the parent block is not executable, there is nothing to do.
252   if (op->getBlock() != nullptr &&
253       !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
254     return success();
255 
256   // We have a live call op. Add this as a live predecessor of the callee.
257   if (auto call = dyn_cast<CallOpInterface>(op))
258     visitCallOperation(call);
259 
260   // Visit the regions.
261   if (op->getNumRegions()) {
262     // Check if we can reason about the region control-flow.
263     if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
264       visitRegionBranchOperation(branch);
265 
266       // Check if this is a callable operation.
267     } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
268       const auto *callsites = getOrCreateFor<PredecessorState>(
269           getProgramPointAfter(op), getProgramPointAfter(callable));
270 
271       // If the callsites could not be resolved or are known to be non-empty,
272       // mark the callable as executable.
273       if (!callsites->allPredecessorsKnown() ||
274           !callsites->getKnownPredecessors().empty())
275         markEntryBlocksLive(callable);
276 
277       // Otherwise, conservatively mark all entry blocks as executable.
278     } else {
279       markEntryBlocksLive(op);
280     }
281   }
282 
283   if (isRegionOrCallableReturn(op)) {
284     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
285       // Visit the exiting terminator of a region.
286       visitRegionTerminator(op, branch);
287     } else if (auto callable =
288                    dyn_cast<CallableOpInterface>(op->getParentOp())) {
289       // Visit the exiting terminator of a callable.
290       visitCallableTerminator(op, callable);
291     }
292   }
293   // Visit the successors.
294   if (op->getNumSuccessors()) {
295     // Check if we can reason about the control-flow.
296     if (auto branch = dyn_cast<BranchOpInterface>(op)) {
297       visitBranchOperation(branch);
298 
299       // Otherwise, conservatively mark all successors as exectuable.
300     } else {
301       for (Block *successor : op->getSuccessors())
302         markEdgeLive(op->getBlock(), successor);
303     }
304   }
305 
306   return success();
307 }
308 
309 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
310   Operation *callableOp = call.resolveCallableInTable(&symbolTable);
311 
312   // A call to a externally-defined callable has unknown predecessors.
313   const auto isExternalCallable = [this](Operation *op) {
314     // A callable outside the analysis scope is an external callable.
315     if (!analysisScope->isAncestor(op))
316       return true;
317     // Otherwise, check if the callable region is defined.
318     if (auto callable = dyn_cast<CallableOpInterface>(op))
319       return !callable.getCallableRegion();
320     return false;
321   };
322 
323   // TODO: Add support for non-symbol callables when necessary. If the
324   // callable has non-call uses we would mark as having reached pessimistic
325   // fixpoint, otherwise allow for propagating the return values out.
326   if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
327       !isExternalCallable(callableOp)) {
328     // Add the live callsite.
329     auto *callsites =
330         getOrCreate<PredecessorState>(getProgramPointAfter(callableOp));
331     propagateIfChanged(callsites, callsites->join(call));
332   } else {
333     // Mark this call op's predecessors as overdefined.
334     auto *predecessors =
335         getOrCreate<PredecessorState>(getProgramPointAfter(call));
336     propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
337   }
338 }
339 
340 /// Get the constant values of the operands of an operation. If any of the
341 /// constant value lattices are uninitialized, return std::nullopt to indicate
342 /// the analysis should bail out.
343 static std::optional<SmallVector<Attribute>> getOperandValuesImpl(
344     Operation *op,
345     function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
346   SmallVector<Attribute> operands;
347   operands.reserve(op->getNumOperands());
348   for (Value operand : op->getOperands()) {
349     const Lattice<ConstantValue> *cv = getLattice(operand);
350     // If any of the operands' values are uninitialized, bail out.
351     if (cv->getValue().isUninitialized())
352       return {};
353     operands.push_back(cv->getValue().getConstantValue());
354   }
355   return operands;
356 }
357 
358 std::optional<SmallVector<Attribute>>
359 DeadCodeAnalysis::getOperandValues(Operation *op) {
360   return getOperandValuesImpl(op, [&](Value value) {
361     auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
362     lattice->useDefSubscribe(this);
363     return lattice;
364   });
365 }
366 
367 void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
368   // Try to deduce a single successor for the branch.
369   std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
370   if (!operands)
371     return;
372 
373   if (Block *successor = branch.getSuccessorForOperands(*operands)) {
374     markEdgeLive(branch->getBlock(), successor);
375   } else {
376     // Otherwise, mark all successors as executable and outgoing edges.
377     for (Block *successor : branch->getSuccessors())
378       markEdgeLive(branch->getBlock(), successor);
379   }
380 }
381 
382 void DeadCodeAnalysis::visitRegionBranchOperation(
383     RegionBranchOpInterface branch) {
384   // Try to deduce which regions are executable.
385   std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
386   if (!operands)
387     return;
388 
389   SmallVector<RegionSuccessor> successors;
390   branch.getEntrySuccessorRegions(*operands, successors);
391   for (const RegionSuccessor &successor : successors) {
392     // The successor can be either an entry block or the parent operation.
393     ProgramPoint *point =
394         successor.getSuccessor()
395             ? getProgramPointBefore(&successor.getSuccessor()->front())
396             : getProgramPointAfter(branch);
397     // Mark the entry block as executable.
398     auto *state = getOrCreate<Executable>(point);
399     propagateIfChanged(state, state->setToLive());
400     // Add the parent op as a predecessor.
401     auto *predecessors = getOrCreate<PredecessorState>(point);
402     propagateIfChanged(
403         predecessors,
404         predecessors->join(branch, successor.getSuccessorInputs()));
405   }
406 }
407 
408 void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
409                                              RegionBranchOpInterface branch) {
410   std::optional<SmallVector<Attribute>> operands = getOperandValues(op);
411   if (!operands)
412     return;
413 
414   SmallVector<RegionSuccessor> successors;
415   if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op))
416     terminator.getSuccessorRegions(*operands, successors);
417   else
418     branch.getSuccessorRegions(op->getParentRegion(), successors);
419 
420   // Mark successor region entry blocks as executable and add this op to the
421   // list of predecessors.
422   for (const RegionSuccessor &successor : successors) {
423     PredecessorState *predecessors;
424     if (Region *region = successor.getSuccessor()) {
425       auto *state =
426           getOrCreate<Executable>(getProgramPointBefore(&region->front()));
427       propagateIfChanged(state, state->setToLive());
428       predecessors = getOrCreate<PredecessorState>(
429           getProgramPointBefore(&region->front()));
430     } else {
431       // Add this terminator as a predecessor to the parent op.
432       predecessors =
433           getOrCreate<PredecessorState>(getProgramPointAfter(branch));
434     }
435     propagateIfChanged(predecessors,
436                        predecessors->join(op, successor.getSuccessorInputs()));
437   }
438 }
439 
440 void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
441                                                CallableOpInterface callable) {
442   // Add as predecessors to all callsites this return op.
443   auto *callsites = getOrCreateFor<PredecessorState>(
444       getProgramPointAfter(op), getProgramPointAfter(callable));
445   bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
446   for (Operation *predecessor : callsites->getKnownPredecessors()) {
447     assert(isa<CallOpInterface>(predecessor));
448     auto *predecessors =
449         getOrCreate<PredecessorState>(getProgramPointAfter(predecessor));
450     if (canResolve) {
451       propagateIfChanged(predecessors, predecessors->join(op));
452     } else {
453       // If the terminator is not a return-like, then conservatively assume we
454       // can't resolve the predecessor.
455       propagateIfChanged(predecessors,
456                          predecessors->setHasUnknownPredecessors());
457     }
458   }
459 }
460