xref: /llvm-project/mlir/lib/Analysis/DataFlowFramework.cpp (revision 3a439e2caf0bb545ee451df1de5b02ea068140f7)
1 //===- DataFlowFramework.cpp - A generic framework for data-flow analysis -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlowFramework.h"
10 #include "mlir/IR/Location.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/Value.h"
13 #include "llvm/ADT/ScopeExit.h"
14 #include "llvm/ADT/iterator.h"
15 #include "llvm/Config/abi-breaking.h"
16 #include "llvm/Support/Casting.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/raw_ostream.h"
19 
#define DEBUG_TYPE "dataflow"

// DATAFLOW_DEBUG(X) forwards X to LLVM_DEBUG only when
// LLVM_ENABLE_ABI_BREAKING_CHECKS is enabled; otherwise X is dropped entirely.
// NOTE(review): the statements passed in reference `debugName` members, which
// presumably only exist under ABI-breaking checks — confirm in the header.
#if !LLVM_ENABLE_ABI_BREAKING_CHECKS
#define DATAFLOW_DEBUG(X)
#else
#define DATAFLOW_DEBUG(X) LLVM_DEBUG(X)
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // GenericLatticeAnchor
31 //===----------------------------------------------------------------------===//
32 
33 GenericLatticeAnchor::~GenericLatticeAnchor() = default;
34 
35 //===----------------------------------------------------------------------===//
36 // AnalysisState
37 //===----------------------------------------------------------------------===//
38 
39 AnalysisState::~AnalysisState() = default;
40 
41 void AnalysisState::addDependency(ProgramPoint *dependent,
42                                   DataFlowAnalysis *analysis) {
43   auto inserted = dependents.insert({dependent, analysis});
44   (void)inserted;
45   DATAFLOW_DEBUG({
46     if (inserted) {
47       llvm::dbgs() << "Creating dependency between " << debugName << " of "
48                    << anchor << "\nand " << debugName << " on " << dependent
49                    << "\n";
50     }
51   });
52 }
53 
54 void AnalysisState::dump() const { print(llvm::errs()); }
55 
56 //===----------------------------------------------------------------------===//
57 // ProgramPoint
58 //===----------------------------------------------------------------------===//
59 
60 void ProgramPoint::print(raw_ostream &os) const {
61   if (isNull()) {
62     os << "<NULL POINT>";
63     return;
64   }
65   if (!isBlockStart()) {
66     os << "<after operation>:";
67     return getPrevOp()->print(os, OpPrintingFlags().skipRegions());
68   }
69   os << "<before operation>:";
70   return getNextOp()->print(os, OpPrintingFlags().skipRegions());
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // LatticeAnchor
75 //===----------------------------------------------------------------------===//
76 
77 void LatticeAnchor::print(raw_ostream &os) const {
78   if (isNull()) {
79     os << "<NULL POINT>";
80     return;
81   }
82   if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
83     return LatticeAnchor->print(os);
84   if (auto value = llvm::dyn_cast<Value>(*this)) {
85     return value.print(os, OpPrintingFlags().skipRegions());
86   }
87 
88   return llvm::cast<ProgramPoint *>(*this)->print(os);
89 }
90 
91 Location LatticeAnchor::getLoc() const {
92   if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
93     return LatticeAnchor->getLoc();
94   if (auto value = llvm::dyn_cast<Value>(*this))
95     return value.getLoc();
96 
97   ProgramPoint *pp = llvm::cast<ProgramPoint *>(*this);
98   if (!pp->isBlockStart())
99     return pp->getPrevOp()->getLoc();
100   return pp->getBlock()->getParent()->getLoc();
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // DataFlowSolver
105 //===----------------------------------------------------------------------===//
106 
107 LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
108   // Enable enqueue to the worklist.
109   isRunning = true;
110   auto guard = llvm::make_scope_exit([&]() { isRunning = false; });
111 
112   // Initialize the analyses.
113   for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
114     DATAFLOW_DEBUG(llvm::dbgs()
115                    << "Priming analysis: " << analysis.debugName << "\n");
116     if (failed(analysis.initialize(top)))
117       return failure();
118   }
119 
120   // Run the analysis until fixpoint.
121   do {
122     // Exhaust the worklist.
123     while (!worklist.empty()) {
124       auto [point, analysis] = worklist.front();
125       worklist.pop();
126 
127       DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
128                                   << "' on: " << point << "\n");
129       if (failed(analysis->visit(point)))
130         return failure();
131     }
132 
133     // Iterate until all states are in some initialized state and the worklist
134     // is exhausted.
135   } while (!worklist.empty());
136 
137   return success();
138 }
139 
140 void DataFlowSolver::propagateIfChanged(AnalysisState *state,
141                                         ChangeResult changed) {
142   assert(isRunning &&
143          "DataFlowSolver is not running, should not use propagateIfChanged");
144   if (changed == ChangeResult::Change) {
145     DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
146                                 << " of " << state->anchor << "\n"
147                                 << "Value: " << *state << "\n");
148     state->onUpdate(this);
149   }
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // DataFlowAnalysis
154 //===----------------------------------------------------------------------===//
155 
156 DataFlowAnalysis::~DataFlowAnalysis() = default;
157 
158 DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
159 
160 void DataFlowAnalysis::addDependency(AnalysisState *state,
161                                      ProgramPoint *point) {
162   state->addDependency(point, this);
163 }
164 
165 void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,
166                                           ChangeResult changed) {
167   solver.propagateIfChanged(state, changed);
168 }
169