xref: /llvm-project/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp (revision 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5)
1c095afcbSMogball //===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===//
2c095afcbSMogball //
3c095afcbSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c095afcbSMogball // See https://llvm.org/LICENSE.txt for license information.
5c095afcbSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c095afcbSMogball //
7c095afcbSMogball //===----------------------------------------------------------------------===//
8c095afcbSMogball 
9c095afcbSMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10c095afcbSMogball #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11c095afcbSMogball #include "mlir/IR/Matchers.h"
12c095afcbSMogball #include "mlir/Pass/Pass.h"
13c095afcbSMogball 
14c095afcbSMogball using namespace mlir;
15c095afcbSMogball using namespace mlir::dataflow;
16c095afcbSMogball 
17c095afcbSMogball /// Print the liveness of every block, control-flow edge, and the predecessors
18c095afcbSMogball /// of all regions, callables, and calls.
19c095afcbSMogball static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
20c095afcbSMogball                                  raw_ostream &os) {
21c095afcbSMogball   op->walk([&](Operation *op) {
22c095afcbSMogball     auto tag = op->getAttrOfType<StringAttr>("tag");
23c095afcbSMogball     if (!tag)
24c095afcbSMogball       return;
25c095afcbSMogball     os << tag.getValue() << ":\n";
26c095afcbSMogball     for (Region &region : op->getRegions()) {
27c095afcbSMogball       os << " region #" << region.getRegionNumber() << "\n";
28c095afcbSMogball       for (Block &block : region) {
29c095afcbSMogball         os << "  ";
30c095afcbSMogball         block.printAsOperand(os);
31c095afcbSMogball         os << " = ";
32*4b3f251bSdonald chen         auto *live = solver.lookupState<Executable>(
33*4b3f251bSdonald chen             solver.getProgramPointBefore(&block));
34c095afcbSMogball         if (live)
35c095afcbSMogball           os << *live;
36c095afcbSMogball         else
37c095afcbSMogball           os << "dead";
38c095afcbSMogball         os << "\n";
39c095afcbSMogball         for (Block *pred : block.getPredecessors()) {
40c095afcbSMogball           os << "   from ";
41c095afcbSMogball           pred->printAsOperand(os);
42c095afcbSMogball           os << " = ";
43c095afcbSMogball           auto *live = solver.lookupState<Executable>(
44b6603e1bSdonald chen               solver.getLatticeAnchor<CFGEdge>(pred, &block));
45c095afcbSMogball           if (live)
46c095afcbSMogball             os << *live;
47c095afcbSMogball           else
48c095afcbSMogball             os << "dead";
49c095afcbSMogball           os << "\n";
50c095afcbSMogball         }
51c095afcbSMogball       }
52c095afcbSMogball       if (!region.empty()) {
53*4b3f251bSdonald chen         auto *preds = solver.lookupState<PredecessorState>(
54*4b3f251bSdonald chen             solver.getProgramPointBefore(&region.front()));
55c095afcbSMogball         if (preds)
56c095afcbSMogball           os << "region_preds: " << *preds << "\n";
57c095afcbSMogball       }
58c095afcbSMogball     }
59*4b3f251bSdonald chen     auto *preds =
60*4b3f251bSdonald chen         solver.lookupState<PredecessorState>(solver.getProgramPointAfter(op));
61c095afcbSMogball     if (preds)
62c095afcbSMogball       os << "op_preds: " << *preds << "\n";
63c095afcbSMogball   });
64c095afcbSMogball }
65c095afcbSMogball 
66c095afcbSMogball namespace {
67c095afcbSMogball /// This is a simple analysis that implements a transfer function for constant
68c095afcbSMogball /// operations.
69c095afcbSMogball struct ConstantAnalysis : public DataFlowAnalysis {
70c095afcbSMogball   using DataFlowAnalysis::DataFlowAnalysis;
71c095afcbSMogball 
72c095afcbSMogball   LogicalResult initialize(Operation *top) override {
73c095afcbSMogball     WalkResult result = top->walk([&](Operation *op) {
74*4b3f251bSdonald chen       if (failed(visit(getProgramPointAfter(op))))
75c095afcbSMogball         return WalkResult::interrupt();
76c095afcbSMogball       return WalkResult::advance();
77c095afcbSMogball     });
78c095afcbSMogball     return success(!result.wasInterrupted());
79c095afcbSMogball   }
80c095afcbSMogball 
81*4b3f251bSdonald chen   LogicalResult visit(ProgramPoint *point) override {
82*4b3f251bSdonald chen     Operation *op = point->getPrevOp();
83c095afcbSMogball     Attribute value;
84c095afcbSMogball     if (matchPattern(op, m_Constant(&value))) {
85c095afcbSMogball       auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
86c095afcbSMogball       propagateIfChanged(
87c095afcbSMogball           constant, constant->join(ConstantValue(value, op->getDialect())));
88c095afcbSMogball       return success();
89c095afcbSMogball     }
90de0ebc52SZhixun Tan     setAllToUnknownConstants(op->getResults());
91ab701975SMogball     for (Region &region : op->getRegions())
92de0ebc52SZhixun Tan       setAllToUnknownConstants(region.getArguments());
93ab701975SMogball     return success();
94ab701975SMogball   }
95ab701975SMogball 
96de0ebc52SZhixun Tan   /// Set all given values as not constants.
97de0ebc52SZhixun Tan   void setAllToUnknownConstants(ValueRange values) {
98ab701975SMogball     for (Value value : values) {
99de0ebc52SZhixun Tan       auto *constant = getOrCreate<Lattice<ConstantValue>>(value);
100de0ebc52SZhixun Tan       propagateIfChanged(constant,
101de0ebc52SZhixun Tan                          constant->join(ConstantValue::getUnknownConstant()));
102ab701975SMogball     }
103ab701975SMogball   }
104c095afcbSMogball };
105c095afcbSMogball 
106ab701975SMogball /// This is a simple pass that runs dead code analysis with a constant value
107ab701975SMogball /// provider that only understands constant operations.
108c095afcbSMogball struct TestDeadCodeAnalysisPass
109c095afcbSMogball     : public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
110c095afcbSMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
111c095afcbSMogball 
112c095afcbSMogball   StringRef getArgument() const override { return "test-dead-code-analysis"; }
113c095afcbSMogball 
114c095afcbSMogball   void runOnOperation() override {
115c095afcbSMogball     Operation *op = getOperation();
116c095afcbSMogball 
117c095afcbSMogball     DataFlowSolver solver;
118c095afcbSMogball     solver.load<DeadCodeAnalysis>();
119c095afcbSMogball     solver.load<ConstantAnalysis>();
120c095afcbSMogball     if (failed(solver.initializeAndRun(op)))
121c095afcbSMogball       return signalPassFailure();
122c095afcbSMogball     printAnalysisResults(solver, op, llvm::errs());
123c095afcbSMogball   }
124c095afcbSMogball };
125c095afcbSMogball } // end anonymous namespace
126c095afcbSMogball 
127c095afcbSMogball namespace mlir {
128c095afcbSMogball namespace test {
129c095afcbSMogball void registerTestDeadCodeAnalysisPass() {
130c095afcbSMogball   PassRegistration<TestDeadCodeAnalysisPass>();
131c095afcbSMogball }
132c095afcbSMogball } // end namespace test
133c095afcbSMogball } // end namespace mlir
134