xref: /llvm-project/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp (revision 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5)
1 //===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/Pass/Pass.h"
13 
14 using namespace mlir;
15 using namespace mlir::dataflow;
16 
17 /// Print the liveness of every block, control-flow edge, and the predecessors
18 /// of all regions, callables, and calls.
19 static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
20                                  raw_ostream &os) {
21   op->walk([&](Operation *op) {
22     auto tag = op->getAttrOfType<StringAttr>("tag");
23     if (!tag)
24       return;
25     os << tag.getValue() << ":\n";
26     for (Region &region : op->getRegions()) {
27       os << " region #" << region.getRegionNumber() << "\n";
28       for (Block &block : region) {
29         os << "  ";
30         block.printAsOperand(os);
31         os << " = ";
32         auto *live = solver.lookupState<Executable>(
33             solver.getProgramPointBefore(&block));
34         if (live)
35           os << *live;
36         else
37           os << "dead";
38         os << "\n";
39         for (Block *pred : block.getPredecessors()) {
40           os << "   from ";
41           pred->printAsOperand(os);
42           os << " = ";
43           auto *live = solver.lookupState<Executable>(
44               solver.getLatticeAnchor<CFGEdge>(pred, &block));
45           if (live)
46             os << *live;
47           else
48             os << "dead";
49           os << "\n";
50         }
51       }
52       if (!region.empty()) {
53         auto *preds = solver.lookupState<PredecessorState>(
54             solver.getProgramPointBefore(&region.front()));
55         if (preds)
56           os << "region_preds: " << *preds << "\n";
57       }
58     }
59     auto *preds =
60         solver.lookupState<PredecessorState>(solver.getProgramPointAfter(op));
61     if (preds)
62       os << "op_preds: " << *preds << "\n";
63   });
64 }
65 
66 namespace {
67 /// This is a simple analysis that implements a transfer function for constant
68 /// operations.
69 struct ConstantAnalysis : public DataFlowAnalysis {
70   using DataFlowAnalysis::DataFlowAnalysis;
71 
72   LogicalResult initialize(Operation *top) override {
73     WalkResult result = top->walk([&](Operation *op) {
74       if (failed(visit(getProgramPointAfter(op))))
75         return WalkResult::interrupt();
76       return WalkResult::advance();
77     });
78     return success(!result.wasInterrupted());
79   }
80 
81   LogicalResult visit(ProgramPoint *point) override {
82     Operation *op = point->getPrevOp();
83     Attribute value;
84     if (matchPattern(op, m_Constant(&value))) {
85       auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
86       propagateIfChanged(
87           constant, constant->join(ConstantValue(value, op->getDialect())));
88       return success();
89     }
90     setAllToUnknownConstants(op->getResults());
91     for (Region &region : op->getRegions())
92       setAllToUnknownConstants(region.getArguments());
93     return success();
94   }
95 
96   /// Set all given values as not constants.
97   void setAllToUnknownConstants(ValueRange values) {
98     for (Value value : values) {
99       auto *constant = getOrCreate<Lattice<ConstantValue>>(value);
100       propagateIfChanged(constant,
101                          constant->join(ConstantValue::getUnknownConstant()));
102     }
103   }
104 };
105 
106 /// This is a simple pass that runs dead code analysis with a constant value
107 /// provider that only understands constant operations.
108 struct TestDeadCodeAnalysisPass
109     : public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
110   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
111 
112   StringRef getArgument() const override { return "test-dead-code-analysis"; }
113 
114   void runOnOperation() override {
115     Operation *op = getOperation();
116 
117     DataFlowSolver solver;
118     solver.load<DeadCodeAnalysis>();
119     solver.load<ConstantAnalysis>();
120     if (failed(solver.initializeAndRun(op)))
121       return signalPassFailure();
122     printAnalysisResults(solver, op, llvm::errs());
123   }
124 };
125 } // end anonymous namespace
126 
127 namespace mlir {
128 namespace test {
129 void registerTestDeadCodeAnalysisPass() {
130   PassRegistration<TestDeadCodeAnalysisPass>();
131 }
132 } // end namespace test
133 } // end namespace mlir
134