1ead75d94SMogball //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===// 2ead75d94SMogball // 3ead75d94SMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ead75d94SMogball // See https://llvm.org/LICENSE.txt for license information. 5ead75d94SMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ead75d94SMogball // 7ead75d94SMogball //===----------------------------------------------------------------------===// 8ead75d94SMogball 9ead75d94SMogball #include "mlir/Analysis/DataFlowFramework.h" 10ead75d94SMogball #include "mlir/Dialect/Func/IR/FuncOps.h" 11ead75d94SMogball #include "mlir/Pass/Pass.h" 12a1fe1f5fSKazu Hirata #include <optional> 13ead75d94SMogball 14ead75d94SMogball using namespace mlir; 15ead75d94SMogball 16ead75d94SMogball namespace { 17ead75d94SMogball /// This analysis state represents an integer that is XOR'd with other states. 18ead75d94SMogball class FooState : public AnalysisState { 19ead75d94SMogball public: 20ead75d94SMogball MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState) 21ead75d94SMogball 22ead75d94SMogball using AnalysisState::AnalysisState; 23ead75d94SMogball 24ead75d94SMogball /// Returns true if the state is uninitialized. 2547bf3e38SZhixun Tan bool isUninitialized() const { return !state; } 26ead75d94SMogball 27ead75d94SMogball /// Print the integer value or "none" if uninitialized. 28ead75d94SMogball void print(raw_ostream &os) const override { 29ead75d94SMogball if (state) 30ead75d94SMogball os << *state; 31ead75d94SMogball else 32ead75d94SMogball os << "none"; 33ead75d94SMogball } 34ead75d94SMogball 35ead75d94SMogball /// Join the state with another. If either is unintialized, take the 36ead75d94SMogball /// initialized value. Otherwise, XOR the integer values. 37ead75d94SMogball ChangeResult join(const FooState &rhs) { 38ead75d94SMogball if (rhs.isUninitialized()) 39ead75d94SMogball return ChangeResult::NoChange; 40ead75d94SMogball return join(*rhs.state); 41ead75d94SMogball } 42ead75d94SMogball ChangeResult join(uint64_t value) { 43ead75d94SMogball if (isUninitialized()) { 44ead75d94SMogball state = value; 45ead75d94SMogball return ChangeResult::Change; 46ead75d94SMogball } 47ead75d94SMogball uint64_t before = *state; 48ead75d94SMogball state = before ^ value; 49ead75d94SMogball return before == *state ? ChangeResult::NoChange : ChangeResult::Change; 50ead75d94SMogball } 51ead75d94SMogball 52ead75d94SMogball /// Set the value of the state directly. 53ead75d94SMogball ChangeResult set(const FooState &rhs) { 54ead75d94SMogball if (state == rhs.state) 55ead75d94SMogball return ChangeResult::NoChange; 56ead75d94SMogball state = rhs.state; 57ead75d94SMogball return ChangeResult::Change; 58ead75d94SMogball } 59ead75d94SMogball 60ead75d94SMogball /// Returns the integer value of the state. 61ead75d94SMogball uint64_t getValue() const { return *state; } 62ead75d94SMogball 63ead75d94SMogball private: 64ead75d94SMogball /// An optional integer value. 650a81ace0SKazu Hirata std::optional<uint64_t> state; 66ead75d94SMogball }; 67ead75d94SMogball 68ead75d94SMogball /// This analysis computes `FooState` across operations and control-flow edges. 69ead75d94SMogball /// If an op specifies a `foo` integer attribute, the contained value is XOR'd 70ead75d94SMogball /// with the value before the operation. 71ead75d94SMogball class FooAnalysis : public DataFlowAnalysis { 72ead75d94SMogball public: 73ead75d94SMogball MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis) 74ead75d94SMogball 75ead75d94SMogball using DataFlowAnalysis::DataFlowAnalysis; 76ead75d94SMogball 77ead75d94SMogball LogicalResult initialize(Operation *top) override; 78*4b3f251bSdonald chen LogicalResult visit(ProgramPoint *point) override; 79ead75d94SMogball 80ead75d94SMogball private: 81ead75d94SMogball void visitBlock(Block *block); 82ead75d94SMogball void visitOperation(Operation *op); 83ead75d94SMogball }; 84ead75d94SMogball 85ead75d94SMogball struct TestFooAnalysisPass 86ead75d94SMogball : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> { 87ead75d94SMogball MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass) 88ead75d94SMogball 89ead75d94SMogball StringRef getArgument() const override { return "test-foo-analysis"; } 90ead75d94SMogball 91ead75d94SMogball void runOnOperation() override; 92ead75d94SMogball }; 93ead75d94SMogball } // namespace 94ead75d94SMogball 95ead75d94SMogball LogicalResult FooAnalysis::initialize(Operation *top) { 96ead75d94SMogball if (top->getNumRegions() != 1) 97ead75d94SMogball return top->emitError("expected a single region top-level op"); 98ead75d94SMogball 99de58a4f1SKai Sasaki if (top->getRegion(0).getBlocks().empty()) 100de58a4f1SKai Sasaki return top->emitError("expected at least one block in the region"); 101de58a4f1SKai Sasaki 102ead75d94SMogball // Initialize the top-level state. 103*4b3f251bSdonald chen (void)getOrCreate<FooState>(getProgramPointBefore(&top->getRegion(0).front())) 104*4b3f251bSdonald chen ->join(0); 105ead75d94SMogball 106ead75d94SMogball // Visit all nested blocks and operations. 107ead75d94SMogball for (Block &block : top->getRegion(0)) { 108ead75d94SMogball visitBlock(&block); 109ead75d94SMogball for (Operation &op : block) { 110ead75d94SMogball if (op.getNumRegions()) 111ead75d94SMogball return op.emitError("unexpected op with regions"); 112ead75d94SMogball visitOperation(&op); 113ead75d94SMogball } 114ead75d94SMogball } 115ead75d94SMogball return success(); 116ead75d94SMogball } 117ead75d94SMogball 118*4b3f251bSdonald chen LogicalResult FooAnalysis::visit(ProgramPoint *point) { 119*4b3f251bSdonald chen if (!point->isBlockStart()) 120*4b3f251bSdonald chen visitOperation(point->getPrevOp()); 121b6603e1bSdonald chen else 122*4b3f251bSdonald chen visitBlock(point->getBlock()); 123ead75d94SMogball return success(); 124ead75d94SMogball } 125ead75d94SMogball 126ead75d94SMogball void FooAnalysis::visitBlock(Block *block) { 127ead75d94SMogball if (block->isEntryBlock()) { 128ead75d94SMogball // This is the initial state. Let the framework default-initialize it. 129ead75d94SMogball return; 130ead75d94SMogball } 131*4b3f251bSdonald chen ProgramPoint *point = getProgramPointBefore(block); 132*4b3f251bSdonald chen FooState *state = getOrCreate<FooState>(point); 133ead75d94SMogball ChangeResult result = ChangeResult::NoChange; 134ead75d94SMogball for (Block *pred : block->getPredecessors()) { 135ead75d94SMogball // Join the state at the terminators of all predecessors. 136*4b3f251bSdonald chen const FooState *predState = getOrCreateFor<FooState>( 137*4b3f251bSdonald chen point, getProgramPointAfter(pred->getTerminator())); 138ead75d94SMogball result |= state->join(*predState); 139ead75d94SMogball } 140ead75d94SMogball propagateIfChanged(state, result); 141ead75d94SMogball } 142ead75d94SMogball 143ead75d94SMogball void FooAnalysis::visitOperation(Operation *op) { 144*4b3f251bSdonald chen ProgramPoint *point = getProgramPointAfter(op); 145*4b3f251bSdonald chen FooState *state = getOrCreate<FooState>(point); 146ead75d94SMogball ChangeResult result = ChangeResult::NoChange; 147ead75d94SMogball 148ead75d94SMogball // Copy the state across the operation. 149ead75d94SMogball const FooState *prevState; 150*4b3f251bSdonald chen prevState = getOrCreateFor<FooState>(point, getProgramPointBefore(op)); 151ead75d94SMogball result |= state->set(*prevState); 152ead75d94SMogball 153ead75d94SMogball // Modify the state with the attribute, if specified. 154ead75d94SMogball if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) { 155ead75d94SMogball uint64_t value = attr.getUInt(); 156ead75d94SMogball result |= state->join(value); 157ead75d94SMogball } 158ead75d94SMogball propagateIfChanged(state, result); 159ead75d94SMogball } 160ead75d94SMogball 161ead75d94SMogball void TestFooAnalysisPass::runOnOperation() { 162ead75d94SMogball func::FuncOp func = getOperation(); 163ead75d94SMogball DataFlowSolver solver; 164ead75d94SMogball solver.load<FooAnalysis>(); 165ead75d94SMogball if (failed(solver.initializeAndRun(func))) 166ead75d94SMogball return signalPassFailure(); 167ead75d94SMogball 168ead75d94SMogball raw_ostream &os = llvm::errs(); 169ead75d94SMogball os << "function: @" << func.getSymName() << "\n"; 170ead75d94SMogball 171ead75d94SMogball func.walk([&](Operation *op) { 172ead75d94SMogball auto tag = op->getAttrOfType<StringAttr>("tag"); 173ead75d94SMogball if (!tag) 174ead75d94SMogball return; 175*4b3f251bSdonald chen const FooState *state = 176*4b3f251bSdonald chen solver.lookupState<FooState>(solver.getProgramPointAfter(op)); 177ead75d94SMogball assert(state && !state->isUninitialized()); 178ead75d94SMogball os << tag.getValue() << " -> " << state->getValue() << "\n"; 179ead75d94SMogball }); 180ead75d94SMogball } 181ead75d94SMogball 182ead75d94SMogball namespace mlir { 183ead75d94SMogball namespace test { 184ead75d94SMogball void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); } 185ead75d94SMogball } // namespace test 186ead75d94SMogball } // namespace mlir 187