1 //===- SparseAnalysis.h - Sparse data-flow analysis -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements sparse data-flow analysis using the data-flow analysis 10 // framework. The analysis is forward and conditional and uses the results of 11 // dead code analysis to prune dead code during the analysis. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H 16 #define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H 17 18 #include "mlir/Analysis/DataFlowFramework.h" 19 #include "mlir/IR/SymbolTable.h" 20 #include "mlir/Interfaces/CallInterfaces.h" 21 #include "mlir/Interfaces/ControlFlowInterfaces.h" 22 #include "llvm/ADT/SmallPtrSet.h" 23 24 namespace mlir { 25 namespace dataflow { 26 27 //===----------------------------------------------------------------------===// 28 // AbstractSparseLattice 29 //===----------------------------------------------------------------------===// 30 31 /// This class represents an abstract lattice. A lattice contains information 32 /// about an SSA value and is what's propagated across the IR by sparse 33 /// data-flow analysis. 34 class AbstractSparseLattice : public AnalysisState { 35 public: 36 /// Lattices can only be created for values. 37 AbstractSparseLattice(Value value) : AnalysisState(value) {} 38 39 /// Return the value this lattice is located at. 40 Value getAnchor() const { return cast<Value>(AnalysisState::getAnchor()); } 41 42 /// Join the information contained in 'rhs' into this lattice. Returns 43 /// if the value of the lattice changed. 44 virtual ChangeResult join(const AbstractSparseLattice &rhs) { 45 return ChangeResult::NoChange; 46 } 47 48 /// Meet (intersect) the information in this lattice with 'rhs'. Returns 49 /// if the value of the lattice changed. 50 virtual ChangeResult meet(const AbstractSparseLattice &rhs) { 51 return ChangeResult::NoChange; 52 } 53 54 /// When the lattice gets updated, propagate an update to users of the value 55 /// using its use-def chain to subscribed analyses. 56 void onUpdate(DataFlowSolver *solver) const override; 57 58 /// Subscribe an analysis to updates of the lattice. When the lattice changes, 59 /// subscribed analyses are re-invoked on all users of the value. This is 60 /// more efficient than relying on the dependency map. 61 void useDefSubscribe(DataFlowAnalysis *analysis) { 62 useDefSubscribers.insert(analysis); 63 } 64 65 private: 66 /// A set of analyses that should be updated when this lattice changes. 67 SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>, 68 SmallPtrSet<DataFlowAnalysis *, 4>> 69 useDefSubscribers; 70 }; 71 72 //===----------------------------------------------------------------------===// 73 // Lattice 74 //===----------------------------------------------------------------------===// 75 76 /// This class represents a lattice holding a specific value of type `ValueT`. 77 /// Lattice values (`ValueT`) are required to adhere to the following: 78 /// 79 /// * static ValueT join(const ValueT &lhs, const ValueT &rhs); 80 /// - This method conservatively joins the information held by `lhs` 81 /// and `rhs` into a new value. This method is required to be monotonic. 82 /// * bool operator==(const ValueT &rhs) const; 83 /// 84 template <typename ValueT> 85 class Lattice : public AbstractSparseLattice { 86 public: 87 using AbstractSparseLattice::AbstractSparseLattice; 88 89 /// Return the value this lattice is located at. 90 Value getAnchor() const { return cast<Value>(anchor); } 91 92 /// Return the value held by this lattice. This requires that the value is 93 /// initialized. 94 ValueT &getValue() { return value; } 95 const ValueT &getValue() const { 96 return const_cast<Lattice<ValueT> *>(this)->getValue(); 97 } 98 99 using LatticeT = Lattice<ValueT>; 100 101 /// Join the information contained in the 'rhs' lattice into this 102 /// lattice. Returns if the state of the current lattice changed. 103 ChangeResult join(const AbstractSparseLattice &rhs) override { 104 return join(static_cast<const LatticeT &>(rhs).getValue()); 105 } 106 107 /// Meet (intersect) the information contained in the 'rhs' lattice with 108 /// this lattice. Returns if the state of the current lattice changed. 109 ChangeResult meet(const AbstractSparseLattice &rhs) override { 110 return meet(static_cast<const LatticeT &>(rhs).getValue()); 111 } 112 113 /// Join the information contained in the 'rhs' value into this 114 /// lattice. Returns if the state of the current lattice changed. 115 ChangeResult join(const ValueT &rhs) { 116 // Otherwise, join rhs with the current optimistic value. 117 ValueT newValue = ValueT::join(value, rhs); 118 assert(ValueT::join(newValue, value) == newValue && 119 "expected `join` to be monotonic"); 120 assert(ValueT::join(newValue, rhs) == newValue && 121 "expected `join` to be monotonic"); 122 123 // Update the current optimistic value if something changed. 124 if (newValue == value) 125 return ChangeResult::NoChange; 126 127 value = newValue; 128 return ChangeResult::Change; 129 } 130 131 /// Trait to check if `T` provides a `meet` method. Needed since for forward 132 /// analysis, lattices will only have a `join`, no `meet`, but we want to use 133 /// the same `Lattice` class for both directions. 134 template <typename T, typename... Args> 135 using has_meet = decltype(&T::meet); 136 template <typename T> 137 using lattice_has_meet = llvm::is_detected<has_meet, T>; 138 139 /// Meet (intersect) the information contained in the 'rhs' value with this 140 /// lattice. Returns if the state of the current lattice changed. If the 141 /// lattice elements don't have a `meet` method, this is a no-op (see below.) 142 template <typename VT, 143 std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr> 144 ChangeResult meet(const VT &rhs) { 145 ValueT newValue = ValueT::meet(value, rhs); 146 assert(ValueT::meet(newValue, value) == newValue && 147 "expected `meet` to be monotonic"); 148 assert(ValueT::meet(newValue, rhs) == newValue && 149 "expected `meet` to be monotonic"); 150 151 // Update the current optimistic value if something changed. 152 if (newValue == value) 153 return ChangeResult::NoChange; 154 155 value = newValue; 156 return ChangeResult::Change; 157 } 158 159 template <typename VT, 160 std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr> 161 ChangeResult meet(const VT &rhs) { 162 return ChangeResult::NoChange; 163 } 164 165 /// Print the lattice element. 166 void print(raw_ostream &os) const override { value.print(os); } 167 168 private: 169 /// The currently computed value that is optimistically assumed to be true. 170 ValueT value; 171 }; 172 173 //===----------------------------------------------------------------------===// 174 // AbstractSparseForwardDataFlowAnalysis 175 //===----------------------------------------------------------------------===// 176 177 /// Base class for sparse forward data-flow analyses. A sparse analysis 178 /// implements a transfer function on operations from the lattices of the 179 /// operands to the lattices of the results. This analysis will propagate 180 /// lattices across control-flow edges and the callgraph using liveness 181 /// information. 182 /// 183 /// Visit a program point in sparse forward data-flow analysis will invoke the 184 /// transfer function of the operation preceding the program point iterator. 185 /// Visit a program point at the begining of block will visit the block itself. 186 class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis { 187 public: 188 /// Initialize the analysis by visiting every owner of an SSA value: all 189 /// operations and blocks. 190 LogicalResult initialize(Operation *top) override; 191 192 /// Visit a program point. If this is at beginning of block and all 193 /// control-flow predecessors or callsites are known, then the arguments 194 /// lattices are propagated from them. If this is after call operation or an 195 /// operation with region control-flow, then its result lattices are set 196 /// accordingly. Otherwise, the operation transfer function is invoked. 197 LogicalResult visit(ProgramPoint *point) override; 198 199 protected: 200 explicit AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver); 201 202 /// The operation transfer function. Given the operand lattices, this 203 /// function is expected to set the result lattices. 204 virtual LogicalResult 205 visitOperationImpl(Operation *op, 206 ArrayRef<const AbstractSparseLattice *> operandLattices, 207 ArrayRef<AbstractSparseLattice *> resultLattices) = 0; 208 209 /// The transfer function for calls to external functions. 210 virtual void visitExternalCallImpl( 211 CallOpInterface call, 212 ArrayRef<const AbstractSparseLattice *> argumentLattices, 213 ArrayRef<AbstractSparseLattice *> resultLattices) = 0; 214 215 /// Given an operation with region control-flow, the lattices of the operands, 216 /// and a region successor, compute the lattice values for block arguments 217 /// that are not accounted for by the branching control flow (ex. the bounds 218 /// of loops). 219 virtual void visitNonControlFlowArgumentsImpl( 220 Operation *op, const RegionSuccessor &successor, 221 ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0; 222 223 /// Get the lattice element of a value. 224 virtual AbstractSparseLattice *getLatticeElement(Value value) = 0; 225 226 /// Get a read-only lattice element for a value and add it as a dependency to 227 /// a program point. 228 const AbstractSparseLattice *getLatticeElementFor(ProgramPoint *point, 229 Value value); 230 231 /// Set the given lattice element(s) at control flow entry point(s). 232 virtual void setToEntryState(AbstractSparseLattice *lattice) = 0; 233 void setAllToEntryStates(ArrayRef<AbstractSparseLattice *> lattices); 234 235 /// Join the lattice element and propagate and update if it changed. 236 void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs); 237 238 private: 239 /// Recursively initialize the analysis on nested operations and blocks. 240 LogicalResult initializeRecursively(Operation *op); 241 242 /// Visit an operation. If this is a call operation or an operation with 243 /// region control-flow, then its result lattices are set accordingly. 244 /// Otherwise, the operation transfer function is invoked. 245 LogicalResult visitOperation(Operation *op); 246 247 /// Visit a block to compute the lattice values of its arguments. If this is 248 /// an entry block, then the argument values are determined from the block's 249 /// "predecessors" as set by `PredecessorState`. The predecessors can be 250 /// region terminators or callable callsites. Otherwise, the values are 251 /// determined from block predecessors. 252 void visitBlock(Block *block); 253 254 /// Visit a program point `point` with predecessors within a region branch 255 /// operation `branch`, which can either be the entry block of one of the 256 /// regions or the parent operation itself, and set either the argument or 257 /// parent result lattices. 258 void visitRegionSuccessors(ProgramPoint *point, 259 RegionBranchOpInterface branch, 260 RegionBranchPoint successor, 261 ArrayRef<AbstractSparseLattice *> lattices); 262 }; 263 264 //===----------------------------------------------------------------------===// 265 // SparseForwardDataFlowAnalysis 266 //===----------------------------------------------------------------------===// 267 268 /// A sparse forward data-flow analysis for propagating SSA value lattices 269 /// across the IR by implementing transfer functions for operations. 270 /// 271 /// `StateT` is expected to be a subclass of `AbstractSparseLattice`. 272 template <typename StateT> 273 class SparseForwardDataFlowAnalysis 274 : public AbstractSparseForwardDataFlowAnalysis { 275 static_assert( 276 std::is_base_of<AbstractSparseLattice, StateT>::value, 277 "analysis state class expected to subclass AbstractSparseLattice"); 278 279 public: 280 explicit SparseForwardDataFlowAnalysis(DataFlowSolver &solver) 281 : AbstractSparseForwardDataFlowAnalysis(solver) {} 282 283 /// Visit an operation with the lattices of its operands. This function is 284 /// expected to set the lattices of the operation's results. 285 virtual LogicalResult visitOperation(Operation *op, 286 ArrayRef<const StateT *> operands, 287 ArrayRef<StateT *> results) = 0; 288 289 /// Visit a call operation to an externally defined function given the 290 /// lattices of its arguments. 291 virtual void visitExternalCall(CallOpInterface call, 292 ArrayRef<const StateT *> argumentLattices, 293 ArrayRef<StateT *> resultLattices) { 294 setAllToEntryStates(resultLattices); 295 } 296 297 /// Given an operation with possible region control-flow, the lattices of the 298 /// operands, and a region successor, compute the lattice values for block 299 /// arguments that are not accounted for by the branching control flow (ex. 300 /// the bounds of loops). By default, this method marks all such lattice 301 /// elements as having reached a pessimistic fixpoint. `firstIndex` is the 302 /// index of the first element of `argLattices` that is set by control-flow. 303 virtual void visitNonControlFlowArguments(Operation *op, 304 const RegionSuccessor &successor, 305 ArrayRef<StateT *> argLattices, 306 unsigned firstIndex) { 307 setAllToEntryStates(argLattices.take_front(firstIndex)); 308 setAllToEntryStates(argLattices.drop_front( 309 firstIndex + successor.getSuccessorInputs().size())); 310 } 311 312 protected: 313 /// Get the lattice element for a value. 314 StateT *getLatticeElement(Value value) override { 315 return getOrCreate<StateT>(value); 316 } 317 318 /// Get the lattice element for a value and create a dependency on the 319 /// provided program point. 320 const StateT *getLatticeElementFor(ProgramPoint *point, Value value) { 321 return static_cast<const StateT *>( 322 AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(point, 323 value)); 324 } 325 326 /// Set the given lattice element(s) at control flow entry point(s). 327 virtual void setToEntryState(StateT *lattice) = 0; 328 void setAllToEntryStates(ArrayRef<StateT *> lattices) { 329 AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates( 330 {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()), 331 lattices.size()}); 332 } 333 334 private: 335 /// Type-erased wrappers that convert the abstract lattice operands to derived 336 /// lattices and invoke the virtual hooks operating on the derived lattices. 337 LogicalResult visitOperationImpl( 338 Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices, 339 ArrayRef<AbstractSparseLattice *> resultLattices) override { 340 return visitOperation( 341 op, 342 {reinterpret_cast<const StateT *const *>(operandLattices.begin()), 343 operandLattices.size()}, 344 {reinterpret_cast<StateT *const *>(resultLattices.begin()), 345 resultLattices.size()}); 346 } 347 void visitExternalCallImpl( 348 CallOpInterface call, 349 ArrayRef<const AbstractSparseLattice *> argumentLattices, 350 ArrayRef<AbstractSparseLattice *> resultLattices) override { 351 visitExternalCall( 352 call, 353 {reinterpret_cast<const StateT *const *>(argumentLattices.begin()), 354 argumentLattices.size()}, 355 {reinterpret_cast<StateT *const *>(resultLattices.begin()), 356 resultLattices.size()}); 357 } 358 void visitNonControlFlowArgumentsImpl( 359 Operation *op, const RegionSuccessor &successor, 360 ArrayRef<AbstractSparseLattice *> argLattices, 361 unsigned firstIndex) override { 362 visitNonControlFlowArguments( 363 op, successor, 364 {reinterpret_cast<StateT *const *>(argLattices.begin()), 365 argLattices.size()}, 366 firstIndex); 367 } 368 void setToEntryState(AbstractSparseLattice *lattice) override { 369 return setToEntryState(reinterpret_cast<StateT *>(lattice)); 370 } 371 }; 372 373 //===----------------------------------------------------------------------===// 374 // AbstractSparseBackwardDataFlowAnalysis 375 //===----------------------------------------------------------------------===// 376 377 /// Base class for sparse backward data-flow analyses. Similar to 378 /// AbstractSparseForwardDataFlowAnalysis, but walks bottom to top. 379 class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis { 380 public: 381 /// Initialize the analysis by visiting the operation and everything nested 382 /// under it. 383 LogicalResult initialize(Operation *top) override; 384 385 /// Visit a program point. If it is after call operation or an operation with 386 /// block or region control-flow, then operand lattices are set accordingly. 387 /// Otherwise, invokes the operation transfer function (`visitOperationImpl`). 388 LogicalResult visit(ProgramPoint *point) override; 389 390 protected: 391 explicit AbstractSparseBackwardDataFlowAnalysis( 392 DataFlowSolver &solver, SymbolTableCollection &symbolTable); 393 394 /// The operation transfer function. Given the result lattices, this 395 /// function is expected to set the operand lattices. 396 virtual LogicalResult visitOperationImpl( 397 Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices, 398 ArrayRef<const AbstractSparseLattice *> resultLattices) = 0; 399 400 /// The transfer function for calls to external functions. 401 virtual void visitExternalCallImpl( 402 CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices, 403 ArrayRef<const AbstractSparseLattice *> resultLattices) = 0; 404 405 // Visit operands on branch instructions that are not forwarded. 406 virtual void visitBranchOperand(OpOperand &operand) = 0; 407 408 // Visit operands on call instructions that are not forwarded. 409 virtual void visitCallOperand(OpOperand &operand) = 0; 410 411 /// Set the given lattice element(s) at control flow exit point(s). 412 virtual void setToExitState(AbstractSparseLattice *lattice) = 0; 413 414 /// Set the given lattice element(s) at control flow exit point(s). 415 void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices); 416 417 /// Get the lattice element for a value. 418 virtual AbstractSparseLattice *getLatticeElement(Value value) = 0; 419 420 /// Get the lattice elements for a range of values. 421 SmallVector<AbstractSparseLattice *> getLatticeElements(ValueRange values); 422 423 /// Join the lattice element and propagate and update if it changed. 424 void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs); 425 426 private: 427 /// Recursively initialize the analysis on nested operations and blocks. 428 LogicalResult initializeRecursively(Operation *op); 429 430 /// Visit an operation. If this is a call operation or an operation with 431 /// region control-flow, then its operand lattices are set accordingly. 432 /// Otherwise, the operation transfer function is invoked. 433 LogicalResult visitOperation(Operation *op); 434 435 /// Visit a block. 436 void visitBlock(Block *block); 437 438 /// Visit an op with regions (like e.g. `scf.while`) 439 void visitRegionSuccessors(RegionBranchOpInterface branch, 440 ArrayRef<AbstractSparseLattice *> operands); 441 442 /// Visit a `RegionBranchTerminatorOpInterface` to compute the lattice values 443 /// of its operands, given its parent op `branch`. The lattice value of an 444 /// operand is determined based on the corresponding arguments in 445 /// `terminator`'s region successor(s). 446 void visitRegionSuccessorsFromTerminator( 447 RegionBranchTerminatorOpInterface terminator, 448 RegionBranchOpInterface branch); 449 450 /// Get the lattice element for a value, and also set up 451 /// dependencies so that the analysis on the given ProgramPoint is re-invoked 452 /// if the value changes. 453 const AbstractSparseLattice *getLatticeElementFor(ProgramPoint *point, 454 Value value); 455 456 /// Get the lattice elements for a range of values, and also set up 457 /// dependencies so that the analysis on the given ProgramPoint is re-invoked 458 /// if any of the values change. 459 SmallVector<const AbstractSparseLattice *> 460 getLatticeElementsFor(ProgramPoint *point, ValueRange values); 461 462 SymbolTableCollection &symbolTable; 463 }; 464 465 //===----------------------------------------------------------------------===// 466 // SparseBackwardDataFlowAnalysis 467 //===----------------------------------------------------------------------===// 468 469 /// A sparse (backward) data-flow analysis for propagating SSA value lattices 470 /// backwards across the IR by implementing transfer functions for operations. 471 /// 472 /// `StateT` is expected to be a subclass of `AbstractSparseLattice`. 473 /// 474 /// Visit a program point in sparse backward data-flow analysis will invoke the 475 /// transfer function of the operation preceding the program point iterator. 476 /// Visit a program point at the begining of block will visit the block itself. 477 template <typename StateT> 478 class SparseBackwardDataFlowAnalysis 479 : public AbstractSparseBackwardDataFlowAnalysis { 480 public: 481 explicit SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, 482 SymbolTableCollection &symbolTable) 483 : AbstractSparseBackwardDataFlowAnalysis(solver, symbolTable) {} 484 485 /// Visit an operation with the lattices of its results. This function is 486 /// expected to set the lattices of the operation's operands. 487 virtual LogicalResult visitOperation(Operation *op, 488 ArrayRef<StateT *> operands, 489 ArrayRef<const StateT *> results) = 0; 490 491 /// Visit a call to an external function. This function is expected to set 492 /// lattice values of the call operands. By default, calls `visitCallOperand` 493 /// for all operands. 494 virtual void visitExternalCall(CallOpInterface call, 495 ArrayRef<StateT *> argumentLattices, 496 ArrayRef<const StateT *> resultLattices) { 497 (void)argumentLattices; 498 (void)resultLattices; 499 for (OpOperand &operand : call->getOpOperands()) { 500 visitCallOperand(operand); 501 } 502 }; 503 504 protected: 505 /// Get the lattice element for a value. 506 StateT *getLatticeElement(Value value) override { 507 return getOrCreate<StateT>(value); 508 } 509 510 /// Set the given lattice element(s) at control flow exit point(s). 511 virtual void setToExitState(StateT *lattice) = 0; 512 void setToExitState(AbstractSparseLattice *lattice) override { 513 return setToExitState(reinterpret_cast<StateT *>(lattice)); 514 } 515 void setAllToExitStates(ArrayRef<StateT *> lattices) { 516 AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates( 517 {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()), 518 lattices.size()}); 519 } 520 521 private: 522 /// Type-erased wrappers that convert the abstract lattice operands to derived 523 /// lattices and invoke the virtual hooks operating on the derived lattices. 524 LogicalResult visitOperationImpl( 525 Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices, 526 ArrayRef<const AbstractSparseLattice *> resultLattices) override { 527 return visitOperation( 528 op, 529 {reinterpret_cast<StateT *const *>(operandLattices.begin()), 530 operandLattices.size()}, 531 {reinterpret_cast<const StateT *const *>(resultLattices.begin()), 532 resultLattices.size()}); 533 } 534 535 void visitExternalCallImpl( 536 CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices, 537 ArrayRef<const AbstractSparseLattice *> resultLattices) override { 538 visitExternalCall( 539 call, 540 {reinterpret_cast<StateT *const *>(operandLattices.begin()), 541 operandLattices.size()}, 542 {reinterpret_cast<const StateT *const *>(resultLattices.begin()), 543 resultLattices.size()}); 544 } 545 }; 546 547 } // end namespace dataflow 548 } // end namespace mlir 549 550 #endif // MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H 551