xref: /llvm-project/mlir/test/lib/Analysis/TestDataFlowFramework.cpp (revision 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5)
1ead75d94SMogball //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
2ead75d94SMogball //
3ead75d94SMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ead75d94SMogball // See https://llvm.org/LICENSE.txt for license information.
5ead75d94SMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ead75d94SMogball //
7ead75d94SMogball //===----------------------------------------------------------------------===//
8ead75d94SMogball 
9ead75d94SMogball #include "mlir/Analysis/DataFlowFramework.h"
10ead75d94SMogball #include "mlir/Dialect/Func/IR/FuncOps.h"
11ead75d94SMogball #include "mlir/Pass/Pass.h"
12a1fe1f5fSKazu Hirata #include <optional>
13ead75d94SMogball 
14ead75d94SMogball using namespace mlir;
15ead75d94SMogball 
16ead75d94SMogball namespace {
17ead75d94SMogball /// This analysis state represents an integer that is XOR'd with other states.
18ead75d94SMogball class FooState : public AnalysisState {
19ead75d94SMogball public:
20ead75d94SMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
21ead75d94SMogball 
22ead75d94SMogball   using AnalysisState::AnalysisState;
23ead75d94SMogball 
24ead75d94SMogball   /// Returns true if the state is uninitialized.
2547bf3e38SZhixun Tan   bool isUninitialized() const { return !state; }
26ead75d94SMogball 
27ead75d94SMogball   /// Print the integer value or "none" if uninitialized.
28ead75d94SMogball   void print(raw_ostream &os) const override {
29ead75d94SMogball     if (state)
30ead75d94SMogball       os << *state;
31ead75d94SMogball     else
32ead75d94SMogball       os << "none";
33ead75d94SMogball   }
34ead75d94SMogball 
35ead75d94SMogball   /// Join the state with another. If either is unintialized, take the
36ead75d94SMogball   /// initialized value. Otherwise, XOR the integer values.
37ead75d94SMogball   ChangeResult join(const FooState &rhs) {
38ead75d94SMogball     if (rhs.isUninitialized())
39ead75d94SMogball       return ChangeResult::NoChange;
40ead75d94SMogball     return join(*rhs.state);
41ead75d94SMogball   }
42ead75d94SMogball   ChangeResult join(uint64_t value) {
43ead75d94SMogball     if (isUninitialized()) {
44ead75d94SMogball       state = value;
45ead75d94SMogball       return ChangeResult::Change;
46ead75d94SMogball     }
47ead75d94SMogball     uint64_t before = *state;
48ead75d94SMogball     state = before ^ value;
49ead75d94SMogball     return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
50ead75d94SMogball   }
51ead75d94SMogball 
52ead75d94SMogball   /// Set the value of the state directly.
53ead75d94SMogball   ChangeResult set(const FooState &rhs) {
54ead75d94SMogball     if (state == rhs.state)
55ead75d94SMogball       return ChangeResult::NoChange;
56ead75d94SMogball     state = rhs.state;
57ead75d94SMogball     return ChangeResult::Change;
58ead75d94SMogball   }
59ead75d94SMogball 
60ead75d94SMogball   /// Returns the integer value of the state.
61ead75d94SMogball   uint64_t getValue() const { return *state; }
62ead75d94SMogball 
63ead75d94SMogball private:
64ead75d94SMogball   /// An optional integer value.
650a81ace0SKazu Hirata   std::optional<uint64_t> state;
66ead75d94SMogball };
67ead75d94SMogball 
68ead75d94SMogball /// This analysis computes `FooState` across operations and control-flow edges.
69ead75d94SMogball /// If an op specifies a `foo` integer attribute, the contained value is XOR'd
70ead75d94SMogball /// with the value before the operation.
71ead75d94SMogball class FooAnalysis : public DataFlowAnalysis {
72ead75d94SMogball public:
73ead75d94SMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
74ead75d94SMogball 
75ead75d94SMogball   using DataFlowAnalysis::DataFlowAnalysis;
76ead75d94SMogball 
77ead75d94SMogball   LogicalResult initialize(Operation *top) override;
78*4b3f251bSdonald chen   LogicalResult visit(ProgramPoint *point) override;
79ead75d94SMogball 
80ead75d94SMogball private:
81ead75d94SMogball   void visitBlock(Block *block);
82ead75d94SMogball   void visitOperation(Operation *op);
83ead75d94SMogball };
84ead75d94SMogball 
85ead75d94SMogball struct TestFooAnalysisPass
86ead75d94SMogball     : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
87ead75d94SMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
88ead75d94SMogball 
89ead75d94SMogball   StringRef getArgument() const override { return "test-foo-analysis"; }
90ead75d94SMogball 
91ead75d94SMogball   void runOnOperation() override;
92ead75d94SMogball };
93ead75d94SMogball } // namespace
94ead75d94SMogball 
95ead75d94SMogball LogicalResult FooAnalysis::initialize(Operation *top) {
96ead75d94SMogball   if (top->getNumRegions() != 1)
97ead75d94SMogball     return top->emitError("expected a single region top-level op");
98ead75d94SMogball 
99de58a4f1SKai Sasaki   if (top->getRegion(0).getBlocks().empty())
100de58a4f1SKai Sasaki     return top->emitError("expected at least one block in the region");
101de58a4f1SKai Sasaki 
102ead75d94SMogball   // Initialize the top-level state.
103*4b3f251bSdonald chen   (void)getOrCreate<FooState>(getProgramPointBefore(&top->getRegion(0).front()))
104*4b3f251bSdonald chen       ->join(0);
105ead75d94SMogball 
106ead75d94SMogball   // Visit all nested blocks and operations.
107ead75d94SMogball   for (Block &block : top->getRegion(0)) {
108ead75d94SMogball     visitBlock(&block);
109ead75d94SMogball     for (Operation &op : block) {
110ead75d94SMogball       if (op.getNumRegions())
111ead75d94SMogball         return op.emitError("unexpected op with regions");
112ead75d94SMogball       visitOperation(&op);
113ead75d94SMogball     }
114ead75d94SMogball   }
115ead75d94SMogball   return success();
116ead75d94SMogball }
117ead75d94SMogball 
118*4b3f251bSdonald chen LogicalResult FooAnalysis::visit(ProgramPoint *point) {
119*4b3f251bSdonald chen   if (!point->isBlockStart())
120*4b3f251bSdonald chen     visitOperation(point->getPrevOp());
121b6603e1bSdonald chen   else
122*4b3f251bSdonald chen     visitBlock(point->getBlock());
123ead75d94SMogball   return success();
124ead75d94SMogball }
125ead75d94SMogball 
126ead75d94SMogball void FooAnalysis::visitBlock(Block *block) {
127ead75d94SMogball   if (block->isEntryBlock()) {
128ead75d94SMogball     // This is the initial state. Let the framework default-initialize it.
129ead75d94SMogball     return;
130ead75d94SMogball   }
131*4b3f251bSdonald chen   ProgramPoint *point = getProgramPointBefore(block);
132*4b3f251bSdonald chen   FooState *state = getOrCreate<FooState>(point);
133ead75d94SMogball   ChangeResult result = ChangeResult::NoChange;
134ead75d94SMogball   for (Block *pred : block->getPredecessors()) {
135ead75d94SMogball     // Join the state at the terminators of all predecessors.
136*4b3f251bSdonald chen     const FooState *predState = getOrCreateFor<FooState>(
137*4b3f251bSdonald chen         point, getProgramPointAfter(pred->getTerminator()));
138ead75d94SMogball     result |= state->join(*predState);
139ead75d94SMogball   }
140ead75d94SMogball   propagateIfChanged(state, result);
141ead75d94SMogball }
142ead75d94SMogball 
143ead75d94SMogball void FooAnalysis::visitOperation(Operation *op) {
144*4b3f251bSdonald chen   ProgramPoint *point = getProgramPointAfter(op);
145*4b3f251bSdonald chen   FooState *state = getOrCreate<FooState>(point);
146ead75d94SMogball   ChangeResult result = ChangeResult::NoChange;
147ead75d94SMogball 
148ead75d94SMogball   // Copy the state across the operation.
149ead75d94SMogball   const FooState *prevState;
150*4b3f251bSdonald chen   prevState = getOrCreateFor<FooState>(point, getProgramPointBefore(op));
151ead75d94SMogball   result |= state->set(*prevState);
152ead75d94SMogball 
153ead75d94SMogball   // Modify the state with the attribute, if specified.
154ead75d94SMogball   if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
155ead75d94SMogball     uint64_t value = attr.getUInt();
156ead75d94SMogball     result |= state->join(value);
157ead75d94SMogball   }
158ead75d94SMogball   propagateIfChanged(state, result);
159ead75d94SMogball }
160ead75d94SMogball 
161ead75d94SMogball void TestFooAnalysisPass::runOnOperation() {
162ead75d94SMogball   func::FuncOp func = getOperation();
163ead75d94SMogball   DataFlowSolver solver;
164ead75d94SMogball   solver.load<FooAnalysis>();
165ead75d94SMogball   if (failed(solver.initializeAndRun(func)))
166ead75d94SMogball     return signalPassFailure();
167ead75d94SMogball 
168ead75d94SMogball   raw_ostream &os = llvm::errs();
169ead75d94SMogball   os << "function: @" << func.getSymName() << "\n";
170ead75d94SMogball 
171ead75d94SMogball   func.walk([&](Operation *op) {
172ead75d94SMogball     auto tag = op->getAttrOfType<StringAttr>("tag");
173ead75d94SMogball     if (!tag)
174ead75d94SMogball       return;
175*4b3f251bSdonald chen     const FooState *state =
176*4b3f251bSdonald chen         solver.lookupState<FooState>(solver.getProgramPointAfter(op));
177ead75d94SMogball     assert(state && !state->isUninitialized());
178ead75d94SMogball     os << tag.getValue() << " -> " << state->getValue() << "\n";
179ead75d94SMogball   });
180ead75d94SMogball }
181ead75d94SMogball 
182ead75d94SMogball namespace mlir {
183ead75d94SMogball namespace test {
184ead75d94SMogball void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
185ead75d94SMogball } // namespace test
186ead75d94SMogball } // namespace mlir
187