//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Interfaces/CallInterfaces.h"

using namespace mlir;
using namespace mlir::dataflow;

//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//

void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
  AnalysisState::onUpdate(solver);

  // Push all users of the value to the queue.
  for (Operation *user : point.get<Value>().getUsers())
    for (DataFlowAnalysis *analysis : useDefSubscribers)
      solver->enqueue({user, analysis});
}

//===----------------------------------------------------------------------===//
// AbstractSparseForwardDataFlowAnalysis
//===----------------------------------------------------------------------===//

AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
    DataFlowSolver &solver)
    : DataFlowAnalysis(solver) {
  registerPointKind<CFGEdge>();
}

LogicalResult
AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
  // Mark the entry block arguments as having reached their pessimistic
  // fixpoints.
  for (Region &region : top->getRegions()) {
    if (region.empty())
      continue;
    for (Value argument : region.front().getArguments())
      setToEntryState(getLatticeElement(argument));
  }

  return initializeRecursively(top);
}

LogicalResult
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
  // Initialize the analysis by visiting every owner of an SSA value (all
  // operations and blocks).
  visitOperation(op);
  for (Region &region : op->getRegions()) {
    for (Block &block : region) {
      getOrCreate<Executable>(&block)->blockContentSubscribe(this);
      visitBlock(&block);
      for (Operation &op : block)
        if (failed(initializeRecursively(&op)))
          return failure();
    }
  }

  return success();
}

LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) {
  if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
    visitOperation(op);
  else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
    visitBlock(block);
  else
    return failure();
  return success();
}

void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
  // Exit early on operations with no results.
  if (op->getNumResults() == 0)
    return;

  // If the containing block is not executable, bail out.
  if (!getOrCreate<Executable>(op->getBlock())->isLive())
    return;

  // Get the result lattices.
  SmallVector<AbstractSparseLattice *> resultLattices;
  resultLattices.reserve(op->getNumResults());
  for (Value result : op->getResults()) {
    AbstractSparseLattice *resultLattice = getLatticeElement(result);
    resultLattices.push_back(resultLattice);
  }

  // The results of a region branch operation are determined by control-flow.
  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
    return visitRegionSuccessors({branch}, branch,
                                 /*successor=*/RegionBranchPoint::parent(),
                                 resultLattices);
  }

  // The results of a call operation are determined by the callgraph.
  if (auto call = dyn_cast<CallOpInterface>(op)) {
    const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
    // If not all return sites are known, then conservatively assume we can't
    // reason about the data-flow.
    if (!predecessors->allPredecessorsKnown())
      return setAllToEntryStates(resultLattices);
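    // Otherwise, the call results are the join of the values returned at
    // every known return site.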
    for (Operation *predecessor : predecessors->getKnownPredecessors())
      for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
        join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
    return;
  }

  // Grab the lattice elements of the operands.
  SmallVector<const AbstractSparseLattice *> operandLattices;
  operandLattices.reserve(op->getNumOperands());
  for (Value operand : op->getOperands()) {
    AbstractSparseLattice *operandLattice = getLatticeElement(operand);
    operandLattice->useDefSubscribe(this);
    operandLattices.push_back(operandLattice);
  }

  // Invoke the operation transfer function.
  visitOperationImpl(op, operandLattices, resultLattices);
}

void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
  // Exit early on blocks with no arguments.
  if (block->getNumArguments() == 0)
    return;

  // If the block is not executable, bail out.
  if (!getOrCreate<Executable>(block)->isLive())
    return;

  // Get the argument lattices.
  SmallVector<AbstractSparseLattice *> argLattices;
  argLattices.reserve(block->getNumArguments());
  for (BlockArgument argument : block->getArguments()) {
    AbstractSparseLattice *argLattice = getLatticeElement(argument);
    argLattices.push_back(argLattice);
  }

  // The argument lattices of entry blocks are set by region control-flow or
  // the callgraph.
  if (block->isEntryBlock()) {
    // Check if this block is the entry block of a callable region.
    auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
    if (callable && callable.getCallableRegion() == block->getParent()) {
      const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
      // If not all callsites are known, conservatively mark all lattices as
      // having reached their pessimistic fixpoints.
      if (!callsites->allPredecessorsKnown())
        return setAllToEntryStates(argLattices);
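      // Otherwise, join the forwarded argument operand of every known call
      // site into the corresponding entry block argument.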
      for (Operation *callsite : callsites->getKnownPredecessors()) {
        auto call = cast<CallOpInterface>(callsite);
        for (auto it : llvm::zip(call.getArgOperands(), argLattices))
          join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
      }
      return;
    }

    // Check if the lattices can be determined from region control flow.
    if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
      return visitRegionSuccessors(block, branch, block->getParent(),
                                   argLattices);
    }

    // Otherwise, we can't reason about the data-flow.
    return visitNonControlFlowArgumentsImpl(block->getParentOp(),
                                            RegionSuccessor(block->getParent()),
                                            argLattices, /*firstIndex=*/0);
  }

  // Iterate over the predecessors of the non-entry block.
  for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
       it != e; ++it) {
    Block *predecessor = *it;

    // If the edge from the predecessor block to the current block is not live,
    // bail out.
    auto *edgeExecutable =
        getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
    edgeExecutable->blockContentSubscribe(this);
    if (!edgeExecutable->isLive())
      continue;

    // Check if we can reason about the data-flow from the predecessor.
    if (auto branch =
            dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
      SuccessorOperands operands =
          branch.getSuccessorOperands(it.getSuccessorIndex());
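      // Join each forwarded successor operand into the lattice of the block
      // argument it feeds.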
      for (auto [idx, lattice] : llvm::enumerate(argLattices)) {
        if (Value operand = operands[idx]) {
          join(lattice, *getLatticeElementFor(block, operand));
        } else {
          // Conservatively consider internally produced arguments as entry
          // points.
          setAllToEntryStates(lattice);
        }
      }
    } else {
      return setAllToEntryStates(argLattices);
    }
  }
}

void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
    ProgramPoint point, RegionBranchOpInterface branch,
    RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
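  // The known predecessors of a region successor are the parent branch op
  // and/or the region terminators that may transfer control to it; this state
  // is populated by the dead code analysis.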
  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
  assert(predecessors->allPredecessorsKnown() &&
         "unexpected unresolved region successors");

  for (Operation *op : predecessors->getKnownPredecessors()) {
    // Get the incoming successor operands.
    std::optional<OperandRange> operands;

    // Check if the predecessor is the parent op.
    if (op == branch) {
      operands = branch.getEntrySuccessorOperands(successor);
      // Otherwise, try to deduce the operands from a region return-like op.
    } else if (auto regionTerminator =
                   dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
      operands = regionTerminator.getSuccessorOperands(successor);
    }

    if (!operands) {
      // We can't reason about the data-flow.
      return setAllToEntryStates(lattices);
    }

    ValueRange inputs = predecessors->getSuccessorInputs(op);
    assert(inputs.size() == operands->size() &&
           "expected the same number of successor inputs as operands");

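    // The successor inputs may cover only part of the lattices, e.g. when not
    // all results or block arguments are forwarded via control flow. Compute
    // the offset of the first forwarded value and let the non-control-flow
    // hook handle the values outside that range.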
    unsigned firstIndex = 0;
    if (inputs.size() != lattices.size()) {
      if (llvm::dyn_cast_if_present<Operation *>(point)) {
        if (!inputs.empty())
          firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
        visitNonControlFlowArgumentsImpl(
            branch,
            RegionSuccessor(
                branch->getResults().slice(firstIndex, inputs.size())),
            lattices, firstIndex);
      } else {
        if (!inputs.empty())
          firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
        Region *region = point.get<Block *>()->getParent();
        visitNonControlFlowArgumentsImpl(
            branch,
            RegionSuccessor(region, region->getArguments().slice(
                                        firstIndex, inputs.size())),
            lattices, firstIndex);
      }
    }

    for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
      join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
  }
}

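// Get the lattice of `value` and record that `point` depends on it, so that
// `point` is re-visited whenever the lattice changes.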
const AbstractSparseLattice *
AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
                                                            Value value) {
  AbstractSparseLattice *state = getLatticeElement(value);
  addDependency(state, point);
  return state;
}

void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
    ArrayRef<AbstractSparseLattice *> lattices) {
  for (AbstractSparseLattice *lattice : lattices)
    setToEntryState(lattice);
}

void AbstractSparseForwardDataFlowAnalysis::join(
    AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
  propagateIfChanged(lhs, lhs->join(rhs));
}

//===----------------------------------------------------------------------===//
// AbstractSparseBackwardDataFlowAnalysis
//===----------------------------------------------------------------------===//

AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
    DataFlowSolver &solver, SymbolTableCollection &symbolTable)
    : DataFlowAnalysis(solver), symbolTable(symbolTable) {
  registerPointKind<CFGEdge>();
}

LogicalResult
AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
  return initializeRecursively(top);
}

LogicalResult
AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
  visitOperation(op);
  for (Region &region : op->getRegions()) {
    for (Block &block : region) {
      getOrCreate<Executable>(&block)->blockContentSubscribe(this);
      // Initialize ops in reverse order, so we can do as much initial
      // propagation as possible without having to go through the
      // solver queue.
      for (auto it = block.rbegin(); it != block.rend(); it++)
        if (failed(initializeRecursively(&*it)))
          return failure();
    }
  }
  return success();
}

LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
  if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
    visitOperation(op);
  else if (llvm::dyn_cast_if_present<Block *>(point))
    // For backward dataflow, we don't have to do any work for the blocks
    // themselves. CFG edges between blocks are processed by the BranchOp
    // logic in `visitOperation`, and entry blocks for functions are tied
    // to the CallOp arguments by visitOperation.
    return success();
  else
    return failure();
  return success();
}

SmallVector<AbstractSparseLattice *>
AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
  SmallVector<AbstractSparseLattice *> resultLattices;
  resultLattices.reserve(values.size());
  for (Value result : values) {
    AbstractSparseLattice *resultLattice = getLatticeElement(result);
    resultLattices.push_back(resultLattice);
  }
  return resultLattices;
}

SmallVector<const AbstractSparseLattice *>
AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
    ProgramPoint point, ValueRange values) {
  SmallVector<const AbstractSparseLattice *> resultLattices;
  resultLattices.reserve(values.size());
  for (Value result : values) {
    const AbstractSparseLattice *resultLattice =
        getLatticeElementFor(point, result);
    resultLattices.push_back(resultLattice);
  }
  return resultLattices;
}

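// Recover the underlying OpOperands of an OperandRange so that callers can
// query and track operand numbers.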
static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
  return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
}

void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
  // If we're in a dead block, bail out.
  if (!getOrCreate<Executable>(op->getBlock())->isLive())
    return;

  SmallVector<AbstractSparseLattice *> operandLattices =
      getLatticeElements(op->getOperands());
  SmallVector<const AbstractSparseLattice *> resultLattices =
      getLatticeElementsFor(op, op->getResults());

  // The block arguments of region branch operations flow back into the
  // operands of the parent op.
  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
    visitRegionSuccessors(branch, operandLattices);
    return;
  }

  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
    // Block arguments of successor blocks flow back into our operands.

    // We remember all operands not forwarded to any block in a BitVector.
    // We can't just cut out a range here, since the non-forwarded operands
    // might be non-contiguous (if there's more than one successor).
    BitVector unaccounted(op->getNumOperands(), true);

    for (auto [index, block] : llvm::enumerate(op->getSuccessors())) {
      SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
      OperandRange forwarded = successorOperands.getForwardedOperands();
      if (!forwarded.empty()) {
        MutableArrayRef<OpOperand> operands = op->getOpOperands().slice(
            forwarded.getBeginOperandIndex(), forwarded.size());
        for (OpOperand &operand : operands) {
          unaccounted.reset(operand.getOperandNumber());
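          // If this operand is forwarded to a successor block argument, meet
          // that argument's lattice into the operand's lattice.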
          if (std::optional<BlockArgument> blockArg =
                  detail::getBranchSuccessorArgument(
                      successorOperands, operand.getOperandNumber(), block)) {
            meet(getLatticeElement(operand.get()),
                 *getLatticeElementFor(op, *blockArg));
          }
        }
      }
    }
    // Operands not forwarded to successor blocks are typically parameters
    // of the branch operation itself (for example the boolean for if/else).
    for (int index : unaccounted.set_bits()) {
      OpOperand &operand = op->getOpOperand(index);
      visitBranchOperand(operand);
    }
    return;
  }

  // For function calls, connect the arguments of the entry blocks to the
  // operands of the call op that are forwarded to these arguments.
  if (auto call = dyn_cast<CallOpInterface>(op)) {
    Operation *callableOp = call.resolveCallable(&symbolTable);
    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
      // Not all operands of a call op forward to arguments. Such operands are
      // stored in `unaccounted`.
      BitVector unaccounted(op->getNumOperands(), true);

      OperandRange argOperands = call.getArgOperands();
      MutableArrayRef<OpOperand> argOpOperands =
          operandsToOpOperands(argOperands);
      Region *region = callable.getCallableRegion();
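      // If the callable has a body, meet each entry block argument's lattice
      // into the corresponding forwarded call operand.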
      if (region && !region->empty()) {
        Block &block = region->front();
        for (auto [blockArg, argOpOperand] :
             llvm::zip(block.getArguments(), argOpOperands)) {
          meet(getLatticeElement(argOpOperand.get()),
               *getLatticeElementFor(op, blockArg));
          unaccounted.reset(argOpOperand.getOperandNumber());
        }
      }
      // Handle the operands of the call op that aren't forwarded to any
      // arguments.
      for (int index : unaccounted.set_bits()) {
        OpOperand &opOperand = op->getOpOperand(index);
        visitCallOperand(opOperand);
      }
      return;
    }
  }

  // When the region of an op implementing `RegionBranchOpInterface` has a
  // terminator implementing `RegionBranchTerminatorOpInterface` or a
  // return-like terminator, the region's successors' arguments flow back into
  // the "successor operands" of this terminator.
  //
  // A successor operand with respect to an op implementing
  // `RegionBranchOpInterface` is an operand that is forwarded to a region
  // successor's input. There are two types of successor operands: the operands
  // of this op itself and the operands of the terminators of the regions of
  // this op.
  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
      visitRegionSuccessorsFromTerminator(terminator, branch);
      return;
    }
  }

  if (op->hasTrait<OpTrait::ReturnLike>()) {
    // Going backwards, the operands of the return are derived from the
    // results of all CallOps calling this CallableOp.
    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
      const PredecessorState *callsites =
          getOrCreateFor<PredecessorState>(op, callable);
      if (callsites->allPredecessorsKnown()) {
        for (Operation *call : callsites->getKnownPredecessors()) {
          SmallVector<const AbstractSparseLattice *> callResultLattices =
              getLatticeElementsFor(op, call->getResults());
          for (auto [op, result] :
               llvm::zip(operandLattices, callResultLattices))
            meet(op, *result);
        }
      } else {
        // If we don't know all the callers, we can't know where the
        // returned values go. Note that, in particular, this will trigger
        // for the return ops of any public functions.
        setAllToExitStates(operandLattices);
      }
      return;
    }
  }

  visitOperationImpl(op, operandLattices, resultLattices);
}

void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
    RegionBranchOpInterface branch,
    ArrayRef<AbstractSparseLattice *> operandLattices) {
  Operation *op = branch.getOperation();
  SmallVector<RegionSuccessor> successors;
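  // The constant values of the operands are unknown at this point, so pass
  // null attributes and consider every possible entry successor region.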
  SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
  branch.getEntrySuccessorRegions(operands, successors);

  // All operands not forwarded to any successor. This set can be
  // non-contiguous in the presence of multiple successors.
  BitVector unaccounted(op->getNumOperands(), true);

  for (RegionSuccessor &successor : successors) {
    OperandRange operands = branch.getEntrySuccessorOperands(successor);
    MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
    ValueRange inputs = successor.getSuccessorInputs();
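    // Meet the lattice of each successor input into the lattice of the
    // operand forwarded to it.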
    for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
      meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input));
      unaccounted.reset(operand.getOperandNumber());
    }
  }
  // All operands not forwarded to regions are typically parameters of the
  // branch operation itself (for example the boolean for if/else).
  for (int index : unaccounted.set_bits()) {
    visitBranchOperand(op->getOpOperand(index));
  }
}

void AbstractSparseBackwardDataFlowAnalysis::
    visitRegionSuccessorsFromTerminator(
        RegionBranchTerminatorOpInterface terminator,
        RegionBranchOpInterface branch) {
  assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
         "expected a `RegionBranchTerminatorOpInterface` op");
  assert(terminator->getParentOp() == branch.getOperation() &&
         "expected `branch` to be the parent op of `terminator`");

  SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
                                           nullptr);
  SmallVector<RegionSuccessor> successors;
  terminator.getSuccessorRegions(operandAttributes, successors);
  // All operands not forwarded to any successor. This set can be
  // non-contiguous in the presence of multiple successors.
  BitVector unaccounted(terminator->getNumOperands(), true);

  for (const RegionSuccessor &successor : successors) {
    ValueRange inputs = successor.getSuccessorInputs();
    OperandRange operands = terminator.getSuccessorOperands(successor);
    MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
    for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
      meet(getLatticeElement(opOperand.get()),
           *getLatticeElementFor(terminator, input));
      unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
    }
  }
  // Visit operands of the terminator that are not forwarded to the next
  // region (e.g. the condition operand of `scf.condition`).
  for (int index : unaccounted.set_bits()) {
    visitBranchOperand(terminator->getOpOperand(index));
  }
}

const AbstractSparseLattice *
AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
                                                             Value value) {
  AbstractSparseLattice *state = getLatticeElement(value);
  addDependency(state, point);
  return state;
}

void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
    ArrayRef<AbstractSparseLattice *> lattices) {
  for (AbstractSparseLattice *lattice : lattices)
    setToExitState(lattice);
}

void AbstractSparseBackwardDataFlowAnalysis::meet(
    AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
  propagateIfChanged(lhs, lhs->meet(rhs));
}