xref: /llvm-project/mlir/test/lib/Analysis/TestDataFlowFramework.cpp (revision 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5)
1 //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlowFramework.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Pass/Pass.h"
12 #include <optional>
13 
14 using namespace mlir;
15 
16 namespace {
17 /// This analysis state represents an integer that is XOR'd with other states.
18 class FooState : public AnalysisState {
19 public:
20   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
21 
22   using AnalysisState::AnalysisState;
23 
24   /// Returns true if the state is uninitialized.
25   bool isUninitialized() const { return !state; }
26 
27   /// Print the integer value or "none" if uninitialized.
28   void print(raw_ostream &os) const override {
29     if (state)
30       os << *state;
31     else
32       os << "none";
33   }
34 
35   /// Join the state with another. If either is unintialized, take the
36   /// initialized value. Otherwise, XOR the integer values.
37   ChangeResult join(const FooState &rhs) {
38     if (rhs.isUninitialized())
39       return ChangeResult::NoChange;
40     return join(*rhs.state);
41   }
42   ChangeResult join(uint64_t value) {
43     if (isUninitialized()) {
44       state = value;
45       return ChangeResult::Change;
46     }
47     uint64_t before = *state;
48     state = before ^ value;
49     return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
50   }
51 
52   /// Set the value of the state directly.
53   ChangeResult set(const FooState &rhs) {
54     if (state == rhs.state)
55       return ChangeResult::NoChange;
56     state = rhs.state;
57     return ChangeResult::Change;
58   }
59 
60   /// Returns the integer value of the state.
61   uint64_t getValue() const { return *state; }
62 
63 private:
64   /// An optional integer value.
65   std::optional<uint64_t> state;
66 };
67 
68 /// This analysis computes `FooState` across operations and control-flow edges.
69 /// If an op specifies a `foo` integer attribute, the contained value is XOR'd
70 /// with the value before the operation.
71 class FooAnalysis : public DataFlowAnalysis {
72 public:
73   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
74 
75   using DataFlowAnalysis::DataFlowAnalysis;
76 
77   LogicalResult initialize(Operation *top) override;
78   LogicalResult visit(ProgramPoint *point) override;
79 
80 private:
81   void visitBlock(Block *block);
82   void visitOperation(Operation *op);
83 };
84 
85 struct TestFooAnalysisPass
86     : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
87   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
88 
89   StringRef getArgument() const override { return "test-foo-analysis"; }
90 
91   void runOnOperation() override;
92 };
93 } // namespace
94 
95 LogicalResult FooAnalysis::initialize(Operation *top) {
96   if (top->getNumRegions() != 1)
97     return top->emitError("expected a single region top-level op");
98 
99   if (top->getRegion(0).getBlocks().empty())
100     return top->emitError("expected at least one block in the region");
101 
102   // Initialize the top-level state.
103   (void)getOrCreate<FooState>(getProgramPointBefore(&top->getRegion(0).front()))
104       ->join(0);
105 
106   // Visit all nested blocks and operations.
107   for (Block &block : top->getRegion(0)) {
108     visitBlock(&block);
109     for (Operation &op : block) {
110       if (op.getNumRegions())
111         return op.emitError("unexpected op with regions");
112       visitOperation(&op);
113     }
114   }
115   return success();
116 }
117 
118 LogicalResult FooAnalysis::visit(ProgramPoint *point) {
119   if (!point->isBlockStart())
120     visitOperation(point->getPrevOp());
121   else
122     visitBlock(point->getBlock());
123   return success();
124 }
125 
126 void FooAnalysis::visitBlock(Block *block) {
127   if (block->isEntryBlock()) {
128     // This is the initial state. Let the framework default-initialize it.
129     return;
130   }
131   ProgramPoint *point = getProgramPointBefore(block);
132   FooState *state = getOrCreate<FooState>(point);
133   ChangeResult result = ChangeResult::NoChange;
134   for (Block *pred : block->getPredecessors()) {
135     // Join the state at the terminators of all predecessors.
136     const FooState *predState = getOrCreateFor<FooState>(
137         point, getProgramPointAfter(pred->getTerminator()));
138     result |= state->join(*predState);
139   }
140   propagateIfChanged(state, result);
141 }
142 
143 void FooAnalysis::visitOperation(Operation *op) {
144   ProgramPoint *point = getProgramPointAfter(op);
145   FooState *state = getOrCreate<FooState>(point);
146   ChangeResult result = ChangeResult::NoChange;
147 
148   // Copy the state across the operation.
149   const FooState *prevState;
150   prevState = getOrCreateFor<FooState>(point, getProgramPointBefore(op));
151   result |= state->set(*prevState);
152 
153   // Modify the state with the attribute, if specified.
154   if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
155     uint64_t value = attr.getUInt();
156     result |= state->join(value);
157   }
158   propagateIfChanged(state, result);
159 }
160 
161 void TestFooAnalysisPass::runOnOperation() {
162   func::FuncOp func = getOperation();
163   DataFlowSolver solver;
164   solver.load<FooAnalysis>();
165   if (failed(solver.initializeAndRun(func)))
166     return signalPassFailure();
167 
168   raw_ostream &os = llvm::errs();
169   os << "function: @" << func.getSymName() << "\n";
170 
171   func.walk([&](Operation *op) {
172     auto tag = op->getAttrOfType<StringAttr>("tag");
173     if (!tag)
174       return;
175     const FooState *state =
176         solver.lookupState<FooState>(solver.getProgramPointAfter(op));
177     assert(state && !state->isUninitialized());
178     os << tag.getValue() << " -> " << state->getValue() << "\n";
179   });
180 }
181 
182 namespace mlir {
183 namespace test {
184 void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
185 } // namespace test
186 } // namespace mlir
187