1 //===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Test pass for backward dense dataflow analysis. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "TestDenseDataFlowAnalysis.h" 14 #include "TestDialect.h" 15 #include "TestOps.h" 16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 17 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 18 #include "mlir/Analysis/DataFlow/DenseAnalysis.h" 19 #include "mlir/Analysis/DataFlowFramework.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/SymbolTable.h" 22 #include "mlir/Interfaces/CallInterfaces.h" 23 #include "mlir/Interfaces/ControlFlowInterfaces.h" 24 #include "mlir/Interfaces/SideEffectInterfaces.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/TypeID.h" 27 #include "llvm/Support/raw_ostream.h" 28 29 using namespace mlir; 30 using namespace mlir::dataflow; 31 using namespace mlir::dataflow::test; 32 33 namespace { 34 35 class NextAccess : public AbstractDenseLattice, public AccessLatticeBase { 36 public: 37 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess) 38 39 using dataflow::AbstractDenseLattice::AbstractDenseLattice; 40 41 ChangeResult meet(const AbstractDenseLattice &lattice) override { 42 return AccessLatticeBase::merge(static_cast<AccessLatticeBase>( 43 static_cast<const NextAccess &>(lattice))); 44 } 45 46 void print(raw_ostream &os) const override { 47 return AccessLatticeBase::print(os); 48 } 49 }; 50 51 class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> { 52 public: 53 NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, 54 bool assumeFuncReads = false) 55 : DenseBackwardDataFlowAnalysis(solver, symbolTable), 56 assumeFuncReads(assumeFuncReads) {} 57 58 LogicalResult visitOperation(Operation *op, const NextAccess &after, 59 NextAccess *before) override; 60 61 void visitCallControlFlowTransfer(CallOpInterface call, 62 CallControlFlowAction action, 63 const NextAccess &after, 64 NextAccess *before) override; 65 66 void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, 67 RegionBranchPoint regionFrom, 68 RegionBranchPoint regionTo, 69 const NextAccess &after, 70 NextAccess *before) override; 71 72 // TODO: this isn't ideal for the analysis. When there is no next access, it 73 // means "we don't know what the next access is" rather than "there is no next 74 // access". But it's unclear how to differentiate the two cases... 75 void setToExitState(NextAccess *lattice) override { 76 propagateIfChanged(lattice, lattice->setKnownToUnknown()); 77 } 78 79 const bool assumeFuncReads; 80 }; 81 } // namespace 82 83 LogicalResult NextAccessAnalysis::visitOperation(Operation *op, 84 const NextAccess &after, 85 NextAccess *before) { 86 auto memory = dyn_cast<MemoryEffectOpInterface>(op); 87 // If we can't reason about the memory effects, conservatively assume we can't 88 // say anything about the next access. 89 if (!memory) { 90 setToExitState(before); 91 return success(); 92 } 93 94 SmallVector<MemoryEffects::EffectInstance> effects; 95 memory.getEffects(effects); 96 97 // First, check if all underlying values are already known. Otherwise, avoid 98 // propagating and stay in the "undefined" state to avoid incorrectly 99 // propagating values that may be overwritten later on as that could be 100 // problematic for convergence based on monotonicity of lattice updates. 101 SmallVector<Value> underlyingValues; 102 underlyingValues.reserve(effects.size()); 103 for (const MemoryEffects::EffectInstance &effect : effects) { 104 Value value = effect.getValue(); 105 106 // Effects with unspecified value are treated conservatively and we cannot 107 // assume anything about the next access. 108 if (!value) { 109 setToExitState(before); 110 return success(); 111 } 112 113 // If cannot find the most underlying value, we cannot assume anything about 114 // the next accesses. 115 std::optional<Value> underlyingValue = 116 UnderlyingValueAnalysis::getMostUnderlyingValue( 117 value, [&](Value value) { 118 return getOrCreateFor<UnderlyingValueLattice>( 119 getProgramPointBefore(op), value); 120 }); 121 122 // If the underlying value is not known yet, don't propagate. 123 if (!underlyingValue) 124 return success(); 125 126 underlyingValues.push_back(*underlyingValue); 127 } 128 129 // Update the state if all underlying values are known. 130 ChangeResult result = before->meet(after); 131 for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) { 132 // If the underlying value is known to be unknown, set to fixpoint. 133 if (!value) { 134 setToExitState(before); 135 return success(); 136 } 137 138 result |= before->set(value, op); 139 } 140 propagateIfChanged(before, result); 141 return success(); 142 } 143 144 void NextAccessAnalysis::visitCallControlFlowTransfer( 145 CallOpInterface call, CallControlFlowAction action, const NextAccess &after, 146 NextAccess *before) { 147 if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) { 148 SmallVector<Value> underlyingValues; 149 underlyingValues.reserve(call->getNumOperands()); 150 for (Value operand : call.getArgOperands()) { 151 std::optional<Value> underlyingValue = 152 UnderlyingValueAnalysis::getMostUnderlyingValue( 153 operand, [&](Value value) { 154 return getOrCreateFor<UnderlyingValueLattice>( 155 getProgramPointBefore(call.getOperation()), value); 156 }); 157 if (!underlyingValue) 158 return; 159 underlyingValues.push_back(*underlyingValue); 160 } 161 162 ChangeResult result = before->meet(after); 163 for (Value operand : underlyingValues) { 164 result |= before->set(operand, call); 165 } 166 return propagateIfChanged(before, result); 167 } 168 auto testCallAndStore = 169 dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); 170 if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && 171 testCallAndStore.getStoreBeforeCall()) || 172 (action == CallControlFlowAction::ExitCallee && 173 !testCallAndStore.getStoreBeforeCall()))) { 174 (void)visitOperation(call, after, before); 175 } else { 176 AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer( 177 call, action, after, before); 178 } 179 } 180 181 void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( 182 RegionBranchOpInterface branch, RegionBranchPoint regionFrom, 183 RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { 184 auto testStoreWithARegion = 185 dyn_cast<::test::TestStoreWithARegion>(branch.getOperation()); 186 187 if (testStoreWithARegion && 188 ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) || 189 (regionFrom.isParent() && 190 testStoreWithARegion.getStoreBeforeRegion()))) { 191 (void)visitOperation(branch, static_cast<const NextAccess &>(after), 192 static_cast<NextAccess *>(before)); 193 } else { 194 propagateIfChanged(before, before->meet(after)); 195 } 196 } 197 198 namespace { 199 struct TestNextAccessPass 200 : public PassWrapper<TestNextAccessPass, OperationPass<>> { 201 TestNextAccessPass() = default; 202 TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) { 203 interprocedural = other.interprocedural; 204 assumeFuncReads = other.assumeFuncReads; 205 } 206 207 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass) 208 209 StringRef getArgument() const override { return "test-next-access"; } 210 211 Option<bool> interprocedural{ 212 *this, "interprocedural", llvm::cl::init(true), 213 llvm::cl::desc("perform interprocedural analysis")}; 214 Option<bool> assumeFuncReads{ 215 *this, "assume-func-reads", llvm::cl::init(false), 216 llvm::cl::desc( 217 "assume external functions have read effect on all arguments")}; 218 219 static constexpr llvm::StringLiteral kTagAttrName = "name"; 220 static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access"; 221 static constexpr llvm::StringLiteral kAtEntryPointAttrName = 222 "next_at_entry_point"; 223 224 static Attribute makeNextAccessAttribute(Operation *op, 225 const DataFlowSolver &solver, 226 const NextAccess *nextAccess) { 227 if (!nextAccess) 228 return StringAttr::get(op->getContext(), "not computed"); 229 230 // Note that if the underlying value could not be computed or is unknown, we 231 // conservatively treat the result also unknown. 232 SmallVector<Attribute> attrs; 233 for (Value operand : op->getOperands()) { 234 std::optional<Value> underlyingValue = 235 UnderlyingValueAnalysis::getMostUnderlyingValue( 236 operand, [&](Value value) { 237 return solver.lookupState<UnderlyingValueLattice>(value); 238 }); 239 if (!underlyingValue) { 240 attrs.push_back(StringAttr::get(op->getContext(), "unknown")); 241 continue; 242 } 243 Value value = *underlyingValue; 244 const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value); 245 if (!nextAcc || !nextAcc->isKnown()) { 246 attrs.push_back(StringAttr::get(op->getContext(), "unknown")); 247 continue; 248 } 249 250 SmallVector<Attribute> innerAttrs; 251 innerAttrs.reserve(nextAcc->get().size()); 252 for (Operation *nextAccOp : nextAcc->get()) { 253 if (auto nextAccTag = 254 nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) { 255 innerAttrs.push_back(nextAccTag); 256 continue; 257 } 258 std::string repr; 259 llvm::raw_string_ostream os(repr); 260 nextAccOp->print(os); 261 innerAttrs.push_back(StringAttr::get(op->getContext(), os.str())); 262 } 263 attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs)); 264 } 265 return ArrayAttr::get(op->getContext(), attrs); 266 } 267 268 void runOnOperation() override { 269 Operation *op = getOperation(); 270 SymbolTableCollection symbolTable; 271 272 auto config = DataFlowConfig().setInterprocedural(interprocedural); 273 DataFlowSolver solver(config); 274 solver.load<DeadCodeAnalysis>(); 275 solver.load<NextAccessAnalysis>(symbolTable, assumeFuncReads); 276 solver.load<SparseConstantPropagation>(); 277 solver.load<UnderlyingValueAnalysis>(); 278 if (failed(solver.initializeAndRun(op))) { 279 emitError(op->getLoc(), "dataflow solver failed"); 280 return signalPassFailure(); 281 } 282 op->walk([&](Operation *op) { 283 auto tag = op->getAttrOfType<StringAttr>(kTagAttrName); 284 if (!tag) 285 return; 286 287 const NextAccess *nextAccess = 288 solver.lookupState<NextAccess>(solver.getProgramPointAfter(op)); 289 op->setAttr(kNextAccessAttrName, 290 makeNextAccessAttribute(op, solver, nextAccess)); 291 292 auto iface = dyn_cast<RegionBranchOpInterface>(op); 293 if (!iface) 294 return; 295 296 SmallVector<Attribute> entryPointNextAccess; 297 SmallVector<RegionSuccessor> regionSuccessors; 298 iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors); 299 for (const RegionSuccessor &successor : regionSuccessors) { 300 if (!successor.getSuccessor() || successor.getSuccessor()->empty()) 301 continue; 302 Block &successorBlock = successor.getSuccessor()->front(); 303 ProgramPoint *successorPoint = 304 solver.getProgramPointBefore(&successorBlock); 305 entryPointNextAccess.push_back(makeNextAccessAttribute( 306 op, solver, solver.lookupState<NextAccess>(successorPoint))); 307 } 308 op->setAttr(kAtEntryPointAttrName, 309 ArrayAttr::get(op->getContext(), entryPointNextAccess)); 310 }); 311 } 312 }; 313 } // namespace 314 315 namespace mlir::test { 316 void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); } 317 } // namespace mlir::test 318