1 //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlowFramework.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Pass/Pass.h" 12 #include <optional> 13 14 using namespace mlir; 15 16 namespace { 17 /// This analysis state represents an integer that is XOR'd with other states. 18 class FooState : public AnalysisState { 19 public: 20 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState) 21 22 using AnalysisState::AnalysisState; 23 24 /// Returns true if the state is uninitialized. 25 bool isUninitialized() const { return !state; } 26 27 /// Print the integer value or "none" if uninitialized. 28 void print(raw_ostream &os) const override { 29 if (state) 30 os << *state; 31 else 32 os << "none"; 33 } 34 35 /// Join the state with another. If either is unintialized, take the 36 /// initialized value. Otherwise, XOR the integer values. 37 ChangeResult join(const FooState &rhs) { 38 if (rhs.isUninitialized()) 39 return ChangeResult::NoChange; 40 return join(*rhs.state); 41 } 42 ChangeResult join(uint64_t value) { 43 if (isUninitialized()) { 44 state = value; 45 return ChangeResult::Change; 46 } 47 uint64_t before = *state; 48 state = before ^ value; 49 return before == *state ? ChangeResult::NoChange : ChangeResult::Change; 50 } 51 52 /// Set the value of the state directly. 53 ChangeResult set(const FooState &rhs) { 54 if (state == rhs.state) 55 return ChangeResult::NoChange; 56 state = rhs.state; 57 return ChangeResult::Change; 58 } 59 60 /// Returns the integer value of the state. 61 uint64_t getValue() const { return *state; } 62 63 private: 64 /// An optional integer value. 65 std::optional<uint64_t> state; 66 }; 67 68 /// This analysis computes `FooState` across operations and control-flow edges. 69 /// If an op specifies a `foo` integer attribute, the contained value is XOR'd 70 /// with the value before the operation. 71 class FooAnalysis : public DataFlowAnalysis { 72 public: 73 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis) 74 75 using DataFlowAnalysis::DataFlowAnalysis; 76 77 LogicalResult initialize(Operation *top) override; 78 LogicalResult visit(ProgramPoint *point) override; 79 80 private: 81 void visitBlock(Block *block); 82 void visitOperation(Operation *op); 83 }; 84 85 struct TestFooAnalysisPass 86 : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> { 87 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass) 88 89 StringRef getArgument() const override { return "test-foo-analysis"; } 90 91 void runOnOperation() override; 92 }; 93 } // namespace 94 95 LogicalResult FooAnalysis::initialize(Operation *top) { 96 if (top->getNumRegions() != 1) 97 return top->emitError("expected a single region top-level op"); 98 99 if (top->getRegion(0).getBlocks().empty()) 100 return top->emitError("expected at least one block in the region"); 101 102 // Initialize the top-level state. 103 (void)getOrCreate<FooState>(getProgramPointBefore(&top->getRegion(0).front())) 104 ->join(0); 105 106 // Visit all nested blocks and operations. 107 for (Block &block : top->getRegion(0)) { 108 visitBlock(&block); 109 for (Operation &op : block) { 110 if (op.getNumRegions()) 111 return op.emitError("unexpected op with regions"); 112 visitOperation(&op); 113 } 114 } 115 return success(); 116 } 117 118 LogicalResult FooAnalysis::visit(ProgramPoint *point) { 119 if (!point->isBlockStart()) 120 visitOperation(point->getPrevOp()); 121 else 122 visitBlock(point->getBlock()); 123 return success(); 124 } 125 126 void FooAnalysis::visitBlock(Block *block) { 127 if (block->isEntryBlock()) { 128 // This is the initial state. Let the framework default-initialize it. 129 return; 130 } 131 ProgramPoint *point = getProgramPointBefore(block); 132 FooState *state = getOrCreate<FooState>(point); 133 ChangeResult result = ChangeResult::NoChange; 134 for (Block *pred : block->getPredecessors()) { 135 // Join the state at the terminators of all predecessors. 136 const FooState *predState = getOrCreateFor<FooState>( 137 point, getProgramPointAfter(pred->getTerminator())); 138 result |= state->join(*predState); 139 } 140 propagateIfChanged(state, result); 141 } 142 143 void FooAnalysis::visitOperation(Operation *op) { 144 ProgramPoint *point = getProgramPointAfter(op); 145 FooState *state = getOrCreate<FooState>(point); 146 ChangeResult result = ChangeResult::NoChange; 147 148 // Copy the state across the operation. 149 const FooState *prevState; 150 prevState = getOrCreateFor<FooState>(point, getProgramPointBefore(op)); 151 result |= state->set(*prevState); 152 153 // Modify the state with the attribute, if specified. 154 if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) { 155 uint64_t value = attr.getUInt(); 156 result |= state->join(value); 157 } 158 propagateIfChanged(state, result); 159 } 160 161 void TestFooAnalysisPass::runOnOperation() { 162 func::FuncOp func = getOperation(); 163 DataFlowSolver solver; 164 solver.load<FooAnalysis>(); 165 if (failed(solver.initializeAndRun(func))) 166 return signalPassFailure(); 167 168 raw_ostream &os = llvm::errs(); 169 os << "function: @" << func.getSymName() << "\n"; 170 171 func.walk([&](Operation *op) { 172 auto tag = op->getAttrOfType<StringAttr>("tag"); 173 if (!tag) 174 return; 175 const FooState *state = 176 solver.lookupState<FooState>(solver.getProgramPointAfter(op)); 177 assert(state && !state->isUninitialized()); 178 os << tag.getValue() << " -> " << state->getValue() << "\n"; 179 }); 180 } 181 182 namespace mlir { 183 namespace test { 184 void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); } 185 } // namespace test 186 } // namespace mlir 187