1 //===- TestBackwardDataFlowAnalysis.cpp - Test dead code analysis ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 11 #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/Interfaces/SideEffectInterfaces.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 using namespace mlir::dataflow; 18 19 namespace { 20 21 /// Lattice value storing the a set of memory resources that something 22 /// is written to. 23 struct WrittenToLatticeValue { 24 bool operator==(const WrittenToLatticeValue &other) { 25 return this->writes == other.writes; 26 } 27 28 static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs, 29 const WrittenToLatticeValue &rhs) { 30 WrittenToLatticeValue res = lhs; 31 (void)res.addWrites(rhs.writes); 32 33 return res; 34 } 35 36 static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs, 37 const WrittenToLatticeValue &rhs) { 38 // Should not be triggered by this test, but required by `Lattice<T>` 39 llvm_unreachable("Join should not be triggered by this test"); 40 } 41 42 ChangeResult addWrites(const SetVector<StringAttr> &writes) { 43 int sizeBefore = this->writes.size(); 44 this->writes.insert(writes.begin(), writes.end()); 45 int sizeAfter = this->writes.size(); 46 return sizeBefore == sizeAfter ? ChangeResult::NoChange 47 : ChangeResult::Change; 48 } 49 50 void print(raw_ostream &os) const { 51 os << "["; 52 llvm::interleave( 53 writes, os, [&](const StringAttr &a) { os << a.str(); }, " "); 54 os << "]"; 55 } 56 57 void clear() { writes.clear(); } 58 59 SetVector<StringAttr> writes; 60 }; 61 62 /// This lattice represents, for a given value, the set of memory resources that 63 /// this value, or anything derived from this value, is potentially written to. 64 struct WrittenTo : public Lattice<WrittenToLatticeValue> { 65 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo) 66 using Lattice::Lattice; 67 }; 68 69 /// An analysis that, by going backwards along the dataflow graph, annotates 70 /// each value with all the memory resources it (or anything derived from it) 71 /// is eventually written to. 72 class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> { 73 public: 74 WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, 75 bool assumeFuncWrites) 76 : SparseBackwardDataFlowAnalysis(solver, symbolTable), 77 assumeFuncWrites(assumeFuncWrites) {} 78 79 LogicalResult visitOperation(Operation *op, ArrayRef<WrittenTo *> operands, 80 ArrayRef<const WrittenTo *> results) override; 81 82 void visitBranchOperand(OpOperand &operand) override; 83 84 void visitCallOperand(OpOperand &operand) override; 85 86 void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands, 87 ArrayRef<const WrittenTo *> results) override; 88 89 void setToExitState(WrittenTo *lattice) override { 90 lattice->getValue().clear(); 91 } 92 93 private: 94 bool assumeFuncWrites; 95 }; 96 97 LogicalResult 98 WrittenToAnalysis::visitOperation(Operation *op, ArrayRef<WrittenTo *> operands, 99 ArrayRef<const WrittenTo *> results) { 100 if (auto store = dyn_cast<memref::StoreOp>(op)) { 101 SetVector<StringAttr> newWrites; 102 newWrites.insert(op->getAttrOfType<StringAttr>("tag_name")); 103 propagateIfChanged(operands[0], 104 operands[0]->getValue().addWrites(newWrites)); 105 return success(); 106 } // By default, every result of an op depends on every operand. 107 for (const WrittenTo *r : results) { 108 for (WrittenTo *operand : operands) { 109 meet(operand, *r); 110 } 111 addDependency(const_cast<WrittenTo *>(r), getProgramPointAfter(op)); 112 } 113 return success(); 114 } 115 116 void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) { 117 // Mark branch operands as "brancharg%d", with %d the operand number. 118 WrittenTo *lattice = getLatticeElement(operand.get()); 119 SetVector<StringAttr> newWrites; 120 newWrites.insert( 121 StringAttr::get(operand.getOwner()->getContext(), 122 "brancharg" + Twine(operand.getOperandNumber()))); 123 propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites)); 124 } 125 126 void WrittenToAnalysis::visitCallOperand(OpOperand &operand) { 127 // Mark call operands as "callarg%d", with %d the operand number. 128 WrittenTo *lattice = getLatticeElement(operand.get()); 129 SetVector<StringAttr> newWrites; 130 newWrites.insert( 131 StringAttr::get(operand.getOwner()->getContext(), 132 "callarg" + Twine(operand.getOperandNumber()))); 133 propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites)); 134 } 135 136 void WrittenToAnalysis::visitExternalCall(CallOpInterface call, 137 ArrayRef<WrittenTo *> operands, 138 ArrayRef<const WrittenTo *> results) { 139 if (!assumeFuncWrites) { 140 return SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands, 141 results); 142 } 143 144 for (WrittenTo *lattice : operands) { 145 SetVector<StringAttr> newWrites; 146 StringAttr name = call->getAttrOfType<StringAttr>("tag_name"); 147 if (!name) { 148 name = StringAttr::get(call->getContext(), 149 call.getOperation()->getName().getStringRef()); 150 } 151 newWrites.insert(name); 152 propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites)); 153 } 154 } 155 156 } // end anonymous namespace 157 158 namespace { 159 struct TestWrittenToPass 160 : public PassWrapper<TestWrittenToPass, OperationPass<>> { 161 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass) 162 163 TestWrittenToPass() = default; 164 TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) { 165 interprocedural = other.interprocedural; 166 assumeFuncWrites = other.assumeFuncWrites; 167 } 168 169 StringRef getArgument() const override { return "test-written-to"; } 170 171 Option<bool> interprocedural{ 172 *this, "interprocedural", llvm::cl::init(true), 173 llvm::cl::desc("perform interprocedural analysis")}; 174 Option<bool> assumeFuncWrites{ 175 *this, "assume-func-writes", llvm::cl::init(false), 176 llvm::cl::desc( 177 "assume external functions have write effect on all arguments")}; 178 179 void runOnOperation() override { 180 Operation *op = getOperation(); 181 182 SymbolTableCollection symbolTable; 183 184 DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural)); 185 solver.load<DeadCodeAnalysis>(); 186 solver.load<SparseConstantPropagation>(); 187 solver.load<WrittenToAnalysis>(symbolTable, assumeFuncWrites); 188 if (failed(solver.initializeAndRun(op))) 189 return signalPassFailure(); 190 191 raw_ostream &os = llvm::outs(); 192 op->walk([&](Operation *op) { 193 auto tag = op->getAttrOfType<StringAttr>("tag"); 194 if (!tag) 195 return; 196 os << "test_tag: " << tag.getValue() << ":\n"; 197 for (auto [index, operand] : llvm::enumerate(op->getOperands())) { 198 const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand); 199 assert(writtenTo && "expected a sparse lattice"); 200 os << " operand #" << index << ": "; 201 writtenTo->print(os); 202 os << "\n"; 203 } 204 for (auto [index, operand] : llvm::enumerate(op->getResults())) { 205 const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand); 206 assert(writtenTo && "expected a sparse lattice"); 207 os << " result #" << index << ": "; 208 writtenTo->print(os); 209 os << "\n"; 210 } 211 }); 212 } 213 }; 214 } // end anonymous namespace 215 216 namespace mlir { 217 namespace test { 218 void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); } 219 } // end namespace test 220 } // end namespace mlir 221