xref: /llvm-project/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp (revision 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5)
//===- TestSparseBackwardDataFlowAnalysis.cpp - Test backward analysis ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 using namespace mlir::dataflow;
18 
19 namespace {
20 
21 /// Lattice value storing the a set of memory resources that something
22 /// is written to.
23 struct WrittenToLatticeValue {
24   bool operator==(const WrittenToLatticeValue &other) {
25     return this->writes == other.writes;
26   }
27 
28   static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
29                                     const WrittenToLatticeValue &rhs) {
30     WrittenToLatticeValue res = lhs;
31     (void)res.addWrites(rhs.writes);
32 
33     return res;
34   }
35 
36   static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
37                                     const WrittenToLatticeValue &rhs) {
38     // Should not be triggered by this test, but required by `Lattice<T>`
39     llvm_unreachable("Join should not be triggered by this test");
40   }
41 
42   ChangeResult addWrites(const SetVector<StringAttr> &writes) {
43     int sizeBefore = this->writes.size();
44     this->writes.insert(writes.begin(), writes.end());
45     int sizeAfter = this->writes.size();
46     return sizeBefore == sizeAfter ? ChangeResult::NoChange
47                                    : ChangeResult::Change;
48   }
49 
50   void print(raw_ostream &os) const {
51     os << "[";
52     llvm::interleave(
53         writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
54     os << "]";
55   }
56 
57   void clear() { writes.clear(); }
58 
59   SetVector<StringAttr> writes;
60 };
61 
62 /// This lattice represents, for a given value, the set of memory resources that
63 /// this value, or anything derived from this value, is potentially written to.
64 struct WrittenTo : public Lattice<WrittenToLatticeValue> {
65   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
66   using Lattice::Lattice;
67 };
68 
69 /// An analysis that, by going backwards along the dataflow graph, annotates
70 /// each value with all the memory resources it (or anything derived from it)
71 /// is eventually written to.
72 class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
73 public:
74   WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
75                     bool assumeFuncWrites)
76       : SparseBackwardDataFlowAnalysis(solver, symbolTable),
77         assumeFuncWrites(assumeFuncWrites) {}
78 
79   LogicalResult visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
80                                ArrayRef<const WrittenTo *> results) override;
81 
82   void visitBranchOperand(OpOperand &operand) override;
83 
84   void visitCallOperand(OpOperand &operand) override;
85 
86   void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
87                          ArrayRef<const WrittenTo *> results) override;
88 
89   void setToExitState(WrittenTo *lattice) override {
90     lattice->getValue().clear();
91   }
92 
93 private:
94   bool assumeFuncWrites;
95 };
96 
97 LogicalResult
98 WrittenToAnalysis::visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
99                                   ArrayRef<const WrittenTo *> results) {
100   if (auto store = dyn_cast<memref::StoreOp>(op)) {
101     SetVector<StringAttr> newWrites;
102     newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
103     propagateIfChanged(operands[0],
104                        operands[0]->getValue().addWrites(newWrites));
105     return success();
106   } // By default, every result of an op depends on every operand.
107   for (const WrittenTo *r : results) {
108     for (WrittenTo *operand : operands) {
109       meet(operand, *r);
110     }
111     addDependency(const_cast<WrittenTo *>(r), getProgramPointAfter(op));
112   }
113   return success();
114 }
115 
116 void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
117   // Mark branch operands as "brancharg%d", with %d the operand number.
118   WrittenTo *lattice = getLatticeElement(operand.get());
119   SetVector<StringAttr> newWrites;
120   newWrites.insert(
121       StringAttr::get(operand.getOwner()->getContext(),
122                       "brancharg" + Twine(operand.getOperandNumber())));
123   propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
124 }
125 
126 void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
127   // Mark call operands as "callarg%d", with %d the operand number.
128   WrittenTo *lattice = getLatticeElement(operand.get());
129   SetVector<StringAttr> newWrites;
130   newWrites.insert(
131       StringAttr::get(operand.getOwner()->getContext(),
132                       "callarg" + Twine(operand.getOperandNumber())));
133   propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
134 }
135 
136 void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
137                                           ArrayRef<WrittenTo *> operands,
138                                           ArrayRef<const WrittenTo *> results) {
139   if (!assumeFuncWrites) {
140     return SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands,
141                                                              results);
142   }
143 
144   for (WrittenTo *lattice : operands) {
145     SetVector<StringAttr> newWrites;
146     StringAttr name = call->getAttrOfType<StringAttr>("tag_name");
147     if (!name) {
148       name = StringAttr::get(call->getContext(),
149                              call.getOperation()->getName().getStringRef());
150     }
151     newWrites.insert(name);
152     propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
153   }
154 }
155 
156 } // end anonymous namespace
157 
158 namespace {
159 struct TestWrittenToPass
160     : public PassWrapper<TestWrittenToPass, OperationPass<>> {
161   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
162 
163   TestWrittenToPass() = default;
164   TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) {
165     interprocedural = other.interprocedural;
166     assumeFuncWrites = other.assumeFuncWrites;
167   }
168 
169   StringRef getArgument() const override { return "test-written-to"; }
170 
171   Option<bool> interprocedural{
172       *this, "interprocedural", llvm::cl::init(true),
173       llvm::cl::desc("perform interprocedural analysis")};
174   Option<bool> assumeFuncWrites{
175       *this, "assume-func-writes", llvm::cl::init(false),
176       llvm::cl::desc(
177           "assume external functions have write effect on all arguments")};
178 
179   void runOnOperation() override {
180     Operation *op = getOperation();
181 
182     SymbolTableCollection symbolTable;
183 
184     DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
185     solver.load<DeadCodeAnalysis>();
186     solver.load<SparseConstantPropagation>();
187     solver.load<WrittenToAnalysis>(symbolTable, assumeFuncWrites);
188     if (failed(solver.initializeAndRun(op)))
189       return signalPassFailure();
190 
191     raw_ostream &os = llvm::outs();
192     op->walk([&](Operation *op) {
193       auto tag = op->getAttrOfType<StringAttr>("tag");
194       if (!tag)
195         return;
196       os << "test_tag: " << tag.getValue() << ":\n";
197       for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
198         const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand);
199         assert(writtenTo && "expected a sparse lattice");
200         os << " operand #" << index << ": ";
201         writtenTo->print(os);
202         os << "\n";
203       }
204       for (auto [index, operand] : llvm::enumerate(op->getResults())) {
205         const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand);
206         assert(writtenTo && "expected a sparse lattice");
207         os << " result #" << index << ": ";
208         writtenTo->print(os);
209         os << "\n";
210       }
211     });
212   }
213 };
214 } // end anonymous namespace
215 
216 namespace mlir {
217 namespace test {
218 void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
219 } // end namespace test
220 } // end namespace mlir
221