xref: /llvm-project/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (revision 4f4e2abb1a5ff1225d32410fd02b732d077aa056)
1 //===- SparseAnalysis.h - Sparse data-flow analysis -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements sparse data-flow analysis, in both the forward and
// backward directions, using the data-flow analysis framework. The forward
// analysis is conditional and uses the results of dead code analysis to prune
// dead code during the analysis.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
16 #define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
17 
18 #include "mlir/Analysis/DataFlowFramework.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "mlir/Interfaces/CallInterfaces.h"
21 #include "mlir/Interfaces/ControlFlowInterfaces.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 
24 namespace mlir {
25 namespace dataflow {
26 
27 //===----------------------------------------------------------------------===//
28 // AbstractSparseLattice
29 //===----------------------------------------------------------------------===//
30 
31 /// This class represents an abstract lattice. A lattice contains information
32 /// about an SSA value and is what's propagated across the IR by sparse
33 /// data-flow analysis.
34 class AbstractSparseLattice : public AnalysisState {
35 public:
36   /// Lattices can only be created for values.
37   AbstractSparseLattice(Value value) : AnalysisState(value) {}
38 
39   /// Return the value this lattice is located at.
40   Value getAnchor() const { return cast<Value>(AnalysisState::getAnchor()); }
41 
42   /// Join the information contained in 'rhs' into this lattice. Returns
43   /// if the value of the lattice changed.
44   virtual ChangeResult join(const AbstractSparseLattice &rhs) {
45     return ChangeResult::NoChange;
46   }
47 
48   /// Meet (intersect) the information in this lattice with 'rhs'. Returns
49   /// if the value of the lattice changed.
50   virtual ChangeResult meet(const AbstractSparseLattice &rhs) {
51     return ChangeResult::NoChange;
52   }
53 
54   /// When the lattice gets updated, propagate an update to users of the value
55   /// using its use-def chain to subscribed analyses.
56   void onUpdate(DataFlowSolver *solver) const override;
57 
58   /// Subscribe an analysis to updates of the lattice. When the lattice changes,
59   /// subscribed analyses are re-invoked on all users of the value. This is
60   /// more efficient than relying on the dependency map.
61   void useDefSubscribe(DataFlowAnalysis *analysis) {
62     useDefSubscribers.insert(analysis);
63   }
64 
65 private:
66   /// A set of analyses that should be updated when this lattice changes.
67   SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
68             SmallPtrSet<DataFlowAnalysis *, 4>>
69       useDefSubscribers;
70 };
71 
72 //===----------------------------------------------------------------------===//
73 // Lattice
74 //===----------------------------------------------------------------------===//
75 
76 /// This class represents a lattice holding a specific value of type `ValueT`.
77 /// Lattice values (`ValueT`) are required to adhere to the following:
78 ///
79 ///   * static ValueT join(const ValueT &lhs, const ValueT &rhs);
80 ///     - This method conservatively joins the information held by `lhs`
81 ///       and `rhs` into a new value. This method is required to be monotonic.
82 ///   * bool operator==(const ValueT &rhs) const;
83 ///
84 template <typename ValueT>
85 class Lattice : public AbstractSparseLattice {
86 public:
87   using AbstractSparseLattice::AbstractSparseLattice;
88 
89   /// Return the value this lattice is located at.
90   Value getAnchor() const { return cast<Value>(anchor); }
91 
92   /// Return the value held by this lattice. This requires that the value is
93   /// initialized.
94   ValueT &getValue() { return value; }
95   const ValueT &getValue() const {
96     return const_cast<Lattice<ValueT> *>(this)->getValue();
97   }
98 
99   using LatticeT = Lattice<ValueT>;
100 
101   /// Join the information contained in the 'rhs' lattice into this
102   /// lattice. Returns if the state of the current lattice changed.
103   ChangeResult join(const AbstractSparseLattice &rhs) override {
104     return join(static_cast<const LatticeT &>(rhs).getValue());
105   }
106 
107   /// Meet (intersect) the information contained in the 'rhs' lattice with
108   /// this lattice. Returns if the state of the current lattice changed.
109   ChangeResult meet(const AbstractSparseLattice &rhs) override {
110     return meet(static_cast<const LatticeT &>(rhs).getValue());
111   }
112 
113   /// Join the information contained in the 'rhs' value into this
114   /// lattice. Returns if the state of the current lattice changed.
115   ChangeResult join(const ValueT &rhs) {
116     // Otherwise, join rhs with the current optimistic value.
117     ValueT newValue = ValueT::join(value, rhs);
118     assert(ValueT::join(newValue, value) == newValue &&
119            "expected `join` to be monotonic");
120     assert(ValueT::join(newValue, rhs) == newValue &&
121            "expected `join` to be monotonic");
122 
123     // Update the current optimistic value if something changed.
124     if (newValue == value)
125       return ChangeResult::NoChange;
126 
127     value = newValue;
128     return ChangeResult::Change;
129   }
130 
131   /// Trait to check if `T` provides a `meet` method. Needed since for forward
132   /// analysis, lattices will only have a `join`, no `meet`, but we want to use
133   /// the same `Lattice` class for both directions.
134   template <typename T, typename... Args>
135   using has_meet = decltype(&T::meet);
136   template <typename T>
137   using lattice_has_meet = llvm::is_detected<has_meet, T>;
138 
139   /// Meet (intersect) the information contained in the 'rhs' value with this
140   /// lattice. Returns if the state of the current lattice changed.  If the
141   /// lattice elements don't have a `meet` method, this is a no-op (see below.)
142   template <typename VT,
143             std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
144   ChangeResult meet(const VT &rhs) {
145     ValueT newValue = ValueT::meet(value, rhs);
146     assert(ValueT::meet(newValue, value) == newValue &&
147            "expected `meet` to be monotonic");
148     assert(ValueT::meet(newValue, rhs) == newValue &&
149            "expected `meet` to be monotonic");
150 
151     // Update the current optimistic value if something changed.
152     if (newValue == value)
153       return ChangeResult::NoChange;
154 
155     value = newValue;
156     return ChangeResult::Change;
157   }
158 
159   template <typename VT,
160             std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
161   ChangeResult meet(const VT &rhs) {
162     return ChangeResult::NoChange;
163   }
164 
165   /// Print the lattice element.
166   void print(raw_ostream &os) const override { value.print(os); }
167 
168 private:
169   /// The currently computed value that is optimistically assumed to be true.
170   ValueT value;
171 };
172 
173 //===----------------------------------------------------------------------===//
174 // AbstractSparseForwardDataFlowAnalysis
175 //===----------------------------------------------------------------------===//
176 
177 /// Base class for sparse forward data-flow analyses. A sparse analysis
178 /// implements a transfer function on operations from the lattices of the
179 /// operands to the lattices of the results. This analysis will propagate
180 /// lattices across control-flow edges and the callgraph using liveness
181 /// information.
182 ///
183 /// Visit a program point in sparse forward data-flow analysis will invoke the
184 /// transfer function of the operation preceding the program point iterator.
185 /// Visit a program point at the begining of block will visit the block itself.
186 class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
187 public:
188   /// Initialize the analysis by visiting every owner of an SSA value: all
189   /// operations and blocks.
190   LogicalResult initialize(Operation *top) override;
191 
192   /// Visit a program point. If this is at beginning of block and all
193   /// control-flow predecessors or callsites are known, then the arguments
194   /// lattices are propagated from them. If this is after call operation or an
195   /// operation with region control-flow, then its result lattices are set
196   /// accordingly.  Otherwise, the operation transfer function is invoked.
197   LogicalResult visit(ProgramPoint *point) override;
198 
199 protected:
200   explicit AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver);
201 
202   /// The operation transfer function. Given the operand lattices, this
203   /// function is expected to set the result lattices.
204   virtual LogicalResult
205   visitOperationImpl(Operation *op,
206                      ArrayRef<const AbstractSparseLattice *> operandLattices,
207                      ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
208 
209   /// The transfer function for calls to external functions.
210   virtual void visitExternalCallImpl(
211       CallOpInterface call,
212       ArrayRef<const AbstractSparseLattice *> argumentLattices,
213       ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
214 
215   /// Given an operation with region control-flow, the lattices of the operands,
216   /// and a region successor, compute the lattice values for block arguments
217   /// that are not accounted for by the branching control flow (ex. the bounds
218   /// of loops).
219   virtual void visitNonControlFlowArgumentsImpl(
220       Operation *op, const RegionSuccessor &successor,
221       ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
222 
223   /// Get the lattice element of a value.
224   virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
225 
226   /// Get a read-only lattice element for a value and add it as a dependency to
227   /// a program point.
228   const AbstractSparseLattice *getLatticeElementFor(ProgramPoint *point,
229                                                     Value value);
230 
231   /// Set the given lattice element(s) at control flow entry point(s).
232   virtual void setToEntryState(AbstractSparseLattice *lattice) = 0;
233   void setAllToEntryStates(ArrayRef<AbstractSparseLattice *> lattices);
234 
235   /// Join the lattice element and propagate and update if it changed.
236   void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
237 
238 private:
239   /// Recursively initialize the analysis on nested operations and blocks.
240   LogicalResult initializeRecursively(Operation *op);
241 
242   /// Visit an operation. If this is a call operation or an operation with
243   /// region control-flow, then its result lattices are set accordingly.
244   /// Otherwise, the operation transfer function is invoked.
245   LogicalResult visitOperation(Operation *op);
246 
247   /// Visit a block to compute the lattice values of its arguments. If this is
248   /// an entry block, then the argument values are determined from the block's
249   /// "predecessors" as set by `PredecessorState`. The predecessors can be
250   /// region terminators or callable callsites. Otherwise, the values are
251   /// determined from block predecessors.
252   void visitBlock(Block *block);
253 
254   /// Visit a program point `point` with predecessors within a region branch
255   /// operation `branch`, which can either be the entry block of one of the
256   /// regions or the parent operation itself, and set either the argument or
257   /// parent result lattices.
258   void visitRegionSuccessors(ProgramPoint *point,
259                              RegionBranchOpInterface branch,
260                              RegionBranchPoint successor,
261                              ArrayRef<AbstractSparseLattice *> lattices);
262 };
263 
264 //===----------------------------------------------------------------------===//
265 // SparseForwardDataFlowAnalysis
266 //===----------------------------------------------------------------------===//
267 
268 /// A sparse forward data-flow analysis for propagating SSA value lattices
269 /// across the IR by implementing transfer functions for operations.
270 ///
271 /// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
272 template <typename StateT>
273 class SparseForwardDataFlowAnalysis
274     : public AbstractSparseForwardDataFlowAnalysis {
275   static_assert(
276       std::is_base_of<AbstractSparseLattice, StateT>::value,
277       "analysis state class expected to subclass AbstractSparseLattice");
278 
279 public:
280   explicit SparseForwardDataFlowAnalysis(DataFlowSolver &solver)
281       : AbstractSparseForwardDataFlowAnalysis(solver) {}
282 
283   /// Visit an operation with the lattices of its operands. This function is
284   /// expected to set the lattices of the operation's results.
285   virtual LogicalResult visitOperation(Operation *op,
286                                        ArrayRef<const StateT *> operands,
287                                        ArrayRef<StateT *> results) = 0;
288 
289   /// Visit a call operation to an externally defined function given the
290   /// lattices of its arguments.
291   virtual void visitExternalCall(CallOpInterface call,
292                                  ArrayRef<const StateT *> argumentLattices,
293                                  ArrayRef<StateT *> resultLattices) {
294     setAllToEntryStates(resultLattices);
295   }
296 
297   /// Given an operation with possible region control-flow, the lattices of the
298   /// operands, and a region successor, compute the lattice values for block
299   /// arguments that are not accounted for by the branching control flow (ex.
300   /// the bounds of loops). By default, this method marks all such lattice
301   /// elements as having reached a pessimistic fixpoint. `firstIndex` is the
302   /// index of the first element of `argLattices` that is set by control-flow.
303   virtual void visitNonControlFlowArguments(Operation *op,
304                                             const RegionSuccessor &successor,
305                                             ArrayRef<StateT *> argLattices,
306                                             unsigned firstIndex) {
307     setAllToEntryStates(argLattices.take_front(firstIndex));
308     setAllToEntryStates(argLattices.drop_front(
309         firstIndex + successor.getSuccessorInputs().size()));
310   }
311 
312 protected:
313   /// Get the lattice element for a value.
314   StateT *getLatticeElement(Value value) override {
315     return getOrCreate<StateT>(value);
316   }
317 
318   /// Get the lattice element for a value and create a dependency on the
319   /// provided program point.
320   const StateT *getLatticeElementFor(ProgramPoint *point, Value value) {
321     return static_cast<const StateT *>(
322         AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(point,
323                                                                     value));
324   }
325 
326   /// Set the given lattice element(s) at control flow entry point(s).
327   virtual void setToEntryState(StateT *lattice) = 0;
328   void setAllToEntryStates(ArrayRef<StateT *> lattices) {
329     AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
330         {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
331          lattices.size()});
332   }
333 
334 private:
335   /// Type-erased wrappers that convert the abstract lattice operands to derived
336   /// lattices and invoke the virtual hooks operating on the derived lattices.
337   LogicalResult visitOperationImpl(
338       Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
339       ArrayRef<AbstractSparseLattice *> resultLattices) override {
340     return visitOperation(
341         op,
342         {reinterpret_cast<const StateT *const *>(operandLattices.begin()),
343          operandLattices.size()},
344         {reinterpret_cast<StateT *const *>(resultLattices.begin()),
345          resultLattices.size()});
346   }
347   void visitExternalCallImpl(
348       CallOpInterface call,
349       ArrayRef<const AbstractSparseLattice *> argumentLattices,
350       ArrayRef<AbstractSparseLattice *> resultLattices) override {
351     visitExternalCall(
352         call,
353         {reinterpret_cast<const StateT *const *>(argumentLattices.begin()),
354          argumentLattices.size()},
355         {reinterpret_cast<StateT *const *>(resultLattices.begin()),
356          resultLattices.size()});
357   }
358   void visitNonControlFlowArgumentsImpl(
359       Operation *op, const RegionSuccessor &successor,
360       ArrayRef<AbstractSparseLattice *> argLattices,
361       unsigned firstIndex) override {
362     visitNonControlFlowArguments(
363         op, successor,
364         {reinterpret_cast<StateT *const *>(argLattices.begin()),
365          argLattices.size()},
366         firstIndex);
367   }
368   void setToEntryState(AbstractSparseLattice *lattice) override {
369     return setToEntryState(reinterpret_cast<StateT *>(lattice));
370   }
371 };
372 
373 //===----------------------------------------------------------------------===//
374 // AbstractSparseBackwardDataFlowAnalysis
375 //===----------------------------------------------------------------------===//
376 
377 /// Base class for sparse backward data-flow analyses. Similar to
378 /// AbstractSparseForwardDataFlowAnalysis, but walks bottom to top.
379 class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
380 public:
381   /// Initialize the analysis by visiting the operation and everything nested
382   /// under it.
383   LogicalResult initialize(Operation *top) override;
384 
385   /// Visit a program point. If it is after call operation or an operation with
386   /// block or region control-flow, then operand lattices are set accordingly.
387   /// Otherwise, invokes the operation transfer function (`visitOperationImpl`).
388   LogicalResult visit(ProgramPoint *point) override;
389 
390 protected:
391   explicit AbstractSparseBackwardDataFlowAnalysis(
392       DataFlowSolver &solver, SymbolTableCollection &symbolTable);
393 
394   /// The operation transfer function. Given the result lattices, this
395   /// function is expected to set the operand lattices.
396   virtual LogicalResult visitOperationImpl(
397       Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
398       ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
399 
400   /// The transfer function for calls to external functions.
401   virtual void visitExternalCallImpl(
402       CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
403       ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
404 
405   // Visit operands on branch instructions that are not forwarded.
406   virtual void visitBranchOperand(OpOperand &operand) = 0;
407 
408   // Visit operands on call instructions that are not forwarded.
409   virtual void visitCallOperand(OpOperand &operand) = 0;
410 
411   /// Set the given lattice element(s) at control flow exit point(s).
412   virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
413 
414   /// Set the given lattice element(s) at control flow exit point(s).
415   void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices);
416 
417   /// Get the lattice element for a value.
418   virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
419 
420   /// Get the lattice elements for a range of values.
421   SmallVector<AbstractSparseLattice *> getLatticeElements(ValueRange values);
422 
423   /// Join the lattice element and propagate and update if it changed.
424   void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
425 
426 private:
427   /// Recursively initialize the analysis on nested operations and blocks.
428   LogicalResult initializeRecursively(Operation *op);
429 
430   /// Visit an operation. If this is a call operation or an operation with
431   /// region control-flow, then its operand lattices are set accordingly.
432   /// Otherwise, the operation transfer function is invoked.
433   LogicalResult visitOperation(Operation *op);
434 
435   /// Visit a block.
436   void visitBlock(Block *block);
437 
438   /// Visit an op with regions (like e.g. `scf.while`)
439   void visitRegionSuccessors(RegionBranchOpInterface branch,
440                              ArrayRef<AbstractSparseLattice *> operands);
441 
442   /// Visit a `RegionBranchTerminatorOpInterface` to compute the lattice values
443   /// of its operands, given its parent op `branch`. The lattice value of an
444   /// operand is determined based on the corresponding arguments in
445   /// `terminator`'s region successor(s).
446   void visitRegionSuccessorsFromTerminator(
447       RegionBranchTerminatorOpInterface terminator,
448       RegionBranchOpInterface branch);
449 
450   /// Get the lattice element for a value, and also set up
451   /// dependencies so that the analysis on the given ProgramPoint is re-invoked
452   /// if the value changes.
453   const AbstractSparseLattice *getLatticeElementFor(ProgramPoint *point,
454                                                     Value value);
455 
456   /// Get the lattice elements for a range of values, and also set up
457   /// dependencies so that the analysis on the given ProgramPoint is re-invoked
458   /// if any of the values change.
459   SmallVector<const AbstractSparseLattice *>
460   getLatticeElementsFor(ProgramPoint *point, ValueRange values);
461 
462   SymbolTableCollection &symbolTable;
463 };
464 
465 //===----------------------------------------------------------------------===//
466 // SparseBackwardDataFlowAnalysis
467 //===----------------------------------------------------------------------===//
468 
469 /// A sparse (backward) data-flow analysis for propagating SSA value lattices
470 /// backwards across the IR by implementing transfer functions for operations.
471 ///
472 /// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
473 ///
474 /// Visit a program point in sparse backward data-flow analysis will invoke the
475 /// transfer function of the operation preceding the program point iterator.
476 /// Visit a program point at the begining of block will visit the block itself.
477 template <typename StateT>
478 class SparseBackwardDataFlowAnalysis
479     : public AbstractSparseBackwardDataFlowAnalysis {
480 public:
481   explicit SparseBackwardDataFlowAnalysis(DataFlowSolver &solver,
482                                           SymbolTableCollection &symbolTable)
483       : AbstractSparseBackwardDataFlowAnalysis(solver, symbolTable) {}
484 
485   /// Visit an operation with the lattices of its results. This function is
486   /// expected to set the lattices of the operation's operands.
487   virtual LogicalResult visitOperation(Operation *op,
488                                        ArrayRef<StateT *> operands,
489                                        ArrayRef<const StateT *> results) = 0;
490 
491   /// Visit a call to an external function. This function is expected to set
492   /// lattice values of the call operands. By default, calls `visitCallOperand`
493   /// for all operands.
494   virtual void visitExternalCall(CallOpInterface call,
495                                  ArrayRef<StateT *> argumentLattices,
496                                  ArrayRef<const StateT *> resultLattices) {
497     (void)argumentLattices;
498     (void)resultLattices;
499     for (OpOperand &operand : call->getOpOperands()) {
500       visitCallOperand(operand);
501     }
502   };
503 
504 protected:
505   /// Get the lattice element for a value.
506   StateT *getLatticeElement(Value value) override {
507     return getOrCreate<StateT>(value);
508   }
509 
510   /// Set the given lattice element(s) at control flow exit point(s).
511   virtual void setToExitState(StateT *lattice) = 0;
512   void setToExitState(AbstractSparseLattice *lattice) override {
513     return setToExitState(reinterpret_cast<StateT *>(lattice));
514   }
515   void setAllToExitStates(ArrayRef<StateT *> lattices) {
516     AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
517         {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
518          lattices.size()});
519   }
520 
521 private:
522   /// Type-erased wrappers that convert the abstract lattice operands to derived
523   /// lattices and invoke the virtual hooks operating on the derived lattices.
524   LogicalResult visitOperationImpl(
525       Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
526       ArrayRef<const AbstractSparseLattice *> resultLattices) override {
527     return visitOperation(
528         op,
529         {reinterpret_cast<StateT *const *>(operandLattices.begin()),
530          operandLattices.size()},
531         {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
532          resultLattices.size()});
533   }
534 
535   void visitExternalCallImpl(
536       CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
537       ArrayRef<const AbstractSparseLattice *> resultLattices) override {
538     visitExternalCall(
539         call,
540         {reinterpret_cast<StateT *const *>(operandLattices.begin()),
541          operandLattices.size()},
542         {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
543          resultLattices.size()});
544   }
545 };
546 
547 } // end namespace dataflow
548 } // end namespace mlir
549 
550 #endif // MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
551