xref: /llvm-project/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (revision 3c7c696a521c8df5a27b26af0aee8a63d5475e6b)
1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 
13 using namespace mlir;
14 using namespace mlir::dataflow;
15 
16 //===----------------------------------------------------------------------===//
17 // AbstractSparseLattice
18 //===----------------------------------------------------------------------===//
19 
20 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
21   // Push all users of the value to the queue.
22   for (Operation *user : point.get<Value>().getUsers())
23     for (DataFlowAnalysis *analysis : useDefSubscribers)
24       solver->enqueue({user, analysis});
25 }
26 
27 //===----------------------------------------------------------------------===//
28 // AbstractSparseDataFlowAnalysis
29 //===----------------------------------------------------------------------===//
30 
/// Construct the forward sparse analysis. CFG edges are used below as
/// program points (to track per-edge executability in `visitBlock`), so the
/// CFGEdge point kind must be registered with the solver up front.
AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis(
    DataFlowSolver &solver)
    : DataFlowAnalysis(solver) {
  registerPointKind<CFGEdge>();
}
36 
37 LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
38   // Mark the entry block arguments as having reached their pessimistic
39   // fixpoints.
40   for (Region &region : top->getRegions()) {
41     if (region.empty())
42       continue;
43     for (Value argument : region.front().getArguments())
44       setToEntryState(getLatticeElement(argument));
45   }
46 
47   return initializeRecursively(top);
48 }
49 
50 LogicalResult
51 AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
52   // Initialize the analysis by visiting every owner of an SSA value (all
53   // operations and blocks).
54   visitOperation(op);
55   for (Region &region : op->getRegions()) {
56     for (Block &block : region) {
57       getOrCreate<Executable>(&block)->blockContentSubscribe(this);
58       visitBlock(&block);
59       for (Operation &op : block)
60         if (failed(initializeRecursively(&op)))
61           return failure();
62     }
63   }
64 
65   return success();
66 }
67 
68 LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
69   if (Operation *op = point.dyn_cast<Operation *>())
70     visitOperation(op);
71   else if (Block *block = point.dyn_cast<Block *>())
72     visitBlock(block);
73   else
74     return failure();
75   return success();
76 }
77 
/// Forward transfer driver for one operation. Result lattices are derived
/// from region control-flow (RegionBranchOpInterface), the callgraph
/// (CallOpInterface), or otherwise the user-supplied transfer function on
/// the operand lattices. The interface checks below are ordered; an op
/// implementing an earlier interface never reaches the generic path.
void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
  // Exit early on operations with no results.
  if (op->getNumResults() == 0)
    return;

  // If the containing block is not executable, bail out.
  if (!getOrCreate<Executable>(op->getBlock())->isLive())
    return;

  // Get the result lattices.
  SmallVector<AbstractSparseLattice *> resultLattices;
  resultLattices.reserve(op->getNumResults());
  for (Value result : op->getResults()) {
    AbstractSparseLattice *resultLattice = getLatticeElement(result);
    resultLattices.push_back(resultLattice);
  }

  // The results of a region branch operation are determined by control-flow.
  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
    // A `nullopt` successor index designates the parent op itself as the
    // region successor (i.e. values flowing out to the op's results).
    return visitRegionSuccessors({branch}, branch,
                                 /*successorIndex=*/std::nullopt,
                                 resultLattices);
  }

  // The results of a call operation are determined by the callgraph.
  if (auto call = dyn_cast<CallOpInterface>(op)) {
    const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
    // If not all return sites are known, then conservatively assume we can't
    // reason about the data-flow.
    if (!predecessors->allPredecessorsKnown())
      return setAllToEntryStates(resultLattices);
    // Join the returned operands of every known return site into the call
    // results.
    for (Operation *predecessor : predecessors->getKnownPredecessors())
      for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
        join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
    return;
  }

  // Grab the lattice elements of the operands.
  SmallVector<const AbstractSparseLattice *> operandLattices;
  operandLattices.reserve(op->getNumOperands());
  for (Value operand : op->getOperands()) {
    AbstractSparseLattice *operandLattice = getLatticeElement(operand);
    // Subscribe so this analysis re-runs on `op` when an operand's lattice
    // changes.
    operandLattice->useDefSubscribe(this);
    operandLattices.push_back(operandLattice);
  }

  // Invoke the operation transfer function.
  visitOperationImpl(op, operandLattices, resultLattices);
}
127 
/// Forward transfer driver for block arguments. Entry-block arguments get
/// their lattices from the callgraph (callable regions) or from region
/// control-flow; non-entry blocks join the operands forwarded by the
/// terminator of each predecessor whose CFG edge to this block is live.
void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
  // Exit early on blocks with no arguments.
  if (block->getNumArguments() == 0)
    return;

  // If the block is not executable, bail out.
  if (!getOrCreate<Executable>(block)->isLive())
    return;

  // Get the argument lattices.
  SmallVector<AbstractSparseLattice *> argLattices;
  argLattices.reserve(block->getNumArguments());
  for (BlockArgument argument : block->getArguments()) {
    AbstractSparseLattice *argLattice = getLatticeElement(argument);
    argLattices.push_back(argLattice);
  }

  // The argument lattices of entry blocks are set by region control-flow or the
  // callgraph.
  if (block->isEntryBlock()) {
    // Check if this block is the entry block of a callable region.
    auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
    if (callable && callable.getCallableRegion() == block->getParent()) {
      const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
      // If not all callsites are known, conservatively mark all lattices as
      // having reached their pessimistic fixpoints.
      if (!callsites->allPredecessorsKnown())
        return setAllToEntryStates(argLattices);
      // Join each known call site's argument operands into the entry
      // arguments, position-wise.
      for (Operation *callsite : callsites->getKnownPredecessors()) {
        auto call = cast<CallOpInterface>(callsite);
        for (auto it : llvm::zip(call.getArgOperands(), argLattices))
          join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
      }
      return;
    }

    // Check if the lattices can be determined from region control flow.
    if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
      return visitRegionSuccessors(
          block, branch, block->getParent()->getRegionNumber(), argLattices);
    }

    // Otherwise, we can't reason about the data-flow.
    return visitNonControlFlowArgumentsImpl(block->getParentOp(),
                                            RegionSuccessor(block->getParent()),
                                            argLattices, /*firstIndex=*/0);
  }

  // Iterate over the predecessors of the non-entry block.
  for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
       it != e; ++it) {
    Block *predecessor = *it;

    // If the edge from the predecessor block to the current block is not live,
    // bail out.
    auto *edgeExecutable =
        getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
    edgeExecutable->blockContentSubscribe(this);
    if (!edgeExecutable->isLive())
      continue;

    // Check if we can reason about the data-flow from the predecessor.
    if (auto branch =
            dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
      SuccessorOperands operands =
          branch.getSuccessorOperands(it.getSuccessorIndex());
      for (auto &it : llvm::enumerate(argLattices)) {
        // `operands[i]` is null when the terminator produces the argument
        // internally rather than forwarding one of its operands.
        if (Value operand = operands[it.index()]) {
          join(it.value(), *getLatticeElementFor(block, operand));
        } else {
          // Conservatively consider internally produced arguments as entry
          // points.
          setAllToEntryStates(it.value());
        }
      }
    } else {
      // Unknown terminator kind: give up on all arguments of this block.
      return setAllToEntryStates(argLattices);
    }
  }
}
208 
/// Propagate lattices across region control-flow edges into `lattices`:
/// either the parent branch op's result lattices (when `successorIndex` is
/// `nullopt`) or a region entry block's argument lattices. Predecessors are
/// the parent op itself and/or region return-like terminators.
void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
    ProgramPoint point, RegionBranchOpInterface branch,
    std::optional<unsigned> successorIndex,
    ArrayRef<AbstractSparseLattice *> lattices) {
  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
  assert(predecessors->allPredecessorsKnown() &&
         "unexpected unresolved region successors");

  for (Operation *op : predecessors->getKnownPredecessors()) {
    // Get the incoming successor operands.
    std::optional<OperandRange> operands;

    // Check if the predecessor is the parent op.
    if (op == branch) {
      operands = branch.getSuccessorEntryOperands(successorIndex);
      // Otherwise, try to deduce the operands from a region return-like op.
    } else {
      if (isRegionReturnLike(op))
        operands = getRegionBranchSuccessorOperands(op, successorIndex);
    }

    if (!operands) {
      // We can't reason about the data-flow.
      return setAllToEntryStates(lattices);
    }

    ValueRange inputs = predecessors->getSuccessorInputs(op);
    assert(inputs.size() == operands->size() &&
           "expected the same number of successor inputs as operands");

    // If fewer values are forwarded than there are lattices, the leading
    // results/arguments are not carried by control flow: hand them to the
    // non-control-flow hook and only join the forwarded tail below.
    unsigned firstIndex = 0;
    if (inputs.size() != lattices.size()) {
      if (point.dyn_cast<Operation *>()) {
        // Successor inputs are op results; find where the forwarded range
        // begins among the op's results.
        if (!inputs.empty())
          firstIndex = inputs.front().cast<OpResult>().getResultNumber();
        visitNonControlFlowArgumentsImpl(
            branch,
            RegionSuccessor(
                branch->getResults().slice(firstIndex, inputs.size())),
            lattices, firstIndex);
      } else {
        // Successor inputs are block arguments; find where the forwarded
        // range begins among the region's arguments.
        if (!inputs.empty())
          firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
        Region *region = point.get<Block *>()->getParent();
        visitNonControlFlowArgumentsImpl(
            branch,
            RegionSuccessor(region, region->getArguments().slice(
                                        firstIndex, inputs.size())),
            lattices, firstIndex);
      }
    }

    // Join the forwarded operand lattices into the successor lattices.
    for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
      join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
  }
}
265 
266 const AbstractSparseLattice *
267 AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
268                                                      Value value) {
269   AbstractSparseLattice *state = getLatticeElement(value);
270   addDependency(state, point);
271   return state;
272 }
273 
274 void AbstractSparseDataFlowAnalysis::setAllToEntryStates(
275     ArrayRef<AbstractSparseLattice *> lattices) {
276   for (AbstractSparseLattice *lattice : lattices)
277     setToEntryState(lattice);
278 }
279 
280 void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
281                                           const AbstractSparseLattice &rhs) {
282   propagateIfChanged(lhs, lhs->join(rhs));
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // AbstractSparseBackwardDataFlowAnalysis
287 //===----------------------------------------------------------------------===//
288 
/// Construct the backward sparse analysis. The symbol table is kept to
/// resolve call ops to their callables in `visitOperation`; the CFGEdge
/// program-point kind is registered because block liveness state keyed on
/// CFG edges is queried by this framework.
AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
    DataFlowSolver &solver, SymbolTableCollection &symbolTable)
    : DataFlowAnalysis(solver), symbolTable(symbolTable) {
  registerPointKind<CFGEdge>();
}
294 
/// Seed the backward analysis by recursively visiting everything under
/// `top`. Unlike the forward analysis, no lattices are pre-seeded here;
/// exit states are handled per-op during the recursive walk.
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
  return initializeRecursively(top);
}
299 
300 LogicalResult
301 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
302   visitOperation(op);
303   for (Region &region : op->getRegions()) {
304     for (Block &block : region) {
305       getOrCreate<Executable>(&block)->blockContentSubscribe(this);
306       // Initialize ops in reverse order, so we can do as much initial
307       // propagation as possible without having to go through the
308       // solver queue.
309       for (auto it = block.rbegin(); it != block.rend(); it++)
310         if (failed(initializeRecursively(&*it)))
311           return failure();
312     }
313   }
314   return success();
315 }
316 
317 LogicalResult
318 AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
319   if (Operation *op = point.dyn_cast<Operation *>())
320     visitOperation(op);
321   else if (point.dyn_cast<Block *>())
322     // For backward dataflow, we don't have to do any work for the blocks
323     // themselves. CFG edges between blocks are processed by the BranchOp
324     // logic in `visitOperation`, and entry blocks for functions are tied
325     // to the CallOp arguments by visitOperation.
326     return success();
327   else
328     return failure();
329   return success();
330 }
331 
332 SmallVector<AbstractSparseLattice *>
333 AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
334   SmallVector<AbstractSparseLattice *> resultLattices;
335   resultLattices.reserve(values.size());
336   for (Value result : values) {
337     AbstractSparseLattice *resultLattice = getLatticeElement(result);
338     resultLattices.push_back(resultLattice);
339   }
340   return resultLattices;
341 }
342 
343 SmallVector<const AbstractSparseLattice *>
344 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
345     ProgramPoint point, ValueRange values) {
346   SmallVector<const AbstractSparseLattice *> resultLattices;
347   resultLattices.reserve(values.size());
348   for (Value result : values) {
349     const AbstractSparseLattice *resultLattice =
350         getLatticeElementFor(point, result);
351     resultLattices.push_back(resultLattice);
352   }
353   return resultLattices;
354 }
355 
356 static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
357   return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
358 }
359 
/// Backward transfer driver for one operation. Information flows from uses
/// to defs: successor block arguments, callee entry arguments, region entry
/// arguments, and parent-op results are met back into this op's operand
/// lattices. The interface checks are ordered; an op matching an earlier
/// case never reaches the generic transfer function at the bottom.
void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
  // If we're in a dead block, bail out.
  if (!getOrCreate<Executable>(op->getBlock())->isLive())
    return;

  SmallVector<AbstractSparseLattice *> operandLattices =
      getLatticeElements(op->getOperands());
  SmallVector<const AbstractSparseLattice *> resultLattices =
      getLatticeElementsFor(op, op->getResults());

  // Block arguments of region branch operations flow back into the operands
  // of the parent op
  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
    visitRegionSuccessors(branch, operandLattices);
    return;
  }

  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
    // Block arguments of successor blocks flow back into our operands.

    // We remember all operands not forwarded to any block in a BitVector.
    // We can't just cut out a range here, since the non-forwarded ops might
    // be non-contiguous (if there's more than one successor).
    BitVector unaccounted(op->getNumOperands(), true);

    for (auto [index, block] : llvm::enumerate(op->getSuccessors())) {
      SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
      OperandRange forwarded = successorOperands.getForwardedOperands();
      if (!forwarded.empty()) {
        MutableArrayRef<OpOperand> operands = op->getOpOperands().slice(
            forwarded.getBeginOperandIndex(), forwarded.size());
        for (OpOperand &operand : operands) {
          unaccounted.reset(operand.getOperandNumber());
          // Meet each successor block argument's lattice back into the
          // operand that is forwarded to it (if any).
          if (std::optional<BlockArgument> blockArg =
                  detail::getBranchSuccessorArgument(
                      successorOperands, operand.getOperandNumber(), block)) {
            meet(getLatticeElement(operand.get()),
                 *getLatticeElementFor(op, *blockArg));
          }
        }
      }
    }
    // Operands not forwarded to successor blocks are typically parameters
    // of the branch operation itself (for example the boolean for if/else).
    for (int index : unaccounted.set_bits()) {
      OpOperand &operand = op->getOpOperand(index);
      visitBranchOperand(operand);
    }
    return;
  }

  // For function calls, connect the arguments of the entry blocks
  // to the operands of the call op.
  if (auto call = dyn_cast<CallOpInterface>(op)) {
    Operation *callableOp = call.resolveCallable(&symbolTable);
    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
      Region *region = callable.getCallableRegion();
      if (!region->empty()) {
        Block &block = region->front();
        // Meet each callee entry argument's lattice into the corresponding
        // call operand, position-wise.
        for (auto [blockArg, operand] :
             llvm::zip(block.getArguments(), operandLattices)) {
          meet(operand, *getLatticeElementFor(op, blockArg));
        }
      }
      return;
    }
  }

  // The block arguments of the branched to region flow back into the
  // operands of the yield operation.
  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
      SmallVector<RegionSuccessor> successors;
      // Null attributes: query successors without constant operand info.
      SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
      branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
                                 operands, successors);
      // All operands not forwarded to any successor. This set can be
      // non-contiguous in the presence of multiple successors.
      BitVector unaccounted(op->getNumOperands(), true);

      for (const RegionSuccessor &successor : successors) {
        ValueRange inputs = successor.getSuccessorInputs();
        Region *region = successor.getSuccessor();
        // A null region designates the parent op as the successor.
        OperandRange operands =
            region ? terminator.getSuccessorOperands(region->getRegionNumber())
                   : terminator.getSuccessorOperands({});
        MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
        for (auto [opoperand, input] : llvm::zip(opoperands, inputs)) {
          meet(getLatticeElement(opoperand.get()),
               *getLatticeElementFor(op, input));
          unaccounted.reset(
              const_cast<OpOperand &>(opoperand).getOperandNumber());
        }
      }
      // Visit operands of the branch op not forwarded to the next region.
      // (Like e.g. the boolean of `scf.conditional`)
      for (int index : unaccounted.set_bits()) {
        visitBranchOperand(op->getOpOperand(index));
      }
      return;
    }
  }

  // yield-like ops usually don't implement `RegionBranchTerminatorOpInterface`,
  // since they behave like a return in the sense that they forward to the
  // results of some other (here: the parent) op.
  if (op->hasTrait<OpTrait::ReturnLike>()) {
    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
      // Assumes a 1:1 positional mapping between yielded operands and the
      // parent op's results.
      OperandRange operands = op->getOperands();
      ResultRange results = op->getParentOp()->getResults();
      assert(results.size() == operands.size() &&
             "Can't derive arg mapping for yield-like op.");
      for (auto [operand, result] : llvm::zip(operands, results))
        meet(getLatticeElement(operand), *getLatticeElementFor(op, result));
      return;
    }

    // Going backwards, the operands of the return are derived from the
    // results of all CallOps calling this CallableOp.
    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
      const PredecessorState *callsites =
          getOrCreateFor<PredecessorState>(op, callable);
      if (callsites->allPredecessorsKnown()) {
        for (Operation *call : callsites->getKnownPredecessors()) {
          SmallVector<const AbstractSparseLattice *> callResultLattices =
              getLatticeElementsFor(op, call->getResults());
          for (auto [op, result] :
               llvm::zip(operandLattices, callResultLattices))
            meet(op, *result);
        }
      } else {
        // If we don't know all the callers, we can't know where the
        // returned values go. Note that, in particular, this will trigger
        // for the return ops of any public functions.
        setAllToExitStates(operandLattices);
      }
      return;
    }
  }

  visitOperationImpl(op, operandLattices, resultLattices);
}
502 
/// Backward propagation across a region branch op: the entry arguments of
/// each successor region (or the op's own results, for the parent-op
/// successor) flow back into the entry operands forwarded to them. Operands
/// not forwarded to any region are handed to `visitBranchOperand`.
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
    RegionBranchOpInterface branch,
    ArrayRef<AbstractSparseLattice *> operandLattices) {
  Operation *op = branch.getOperation();
  SmallVector<RegionSuccessor> successors;
  // Null attributes: query successors without constant operand info.
  SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
  branch.getSuccessorRegions(/*index=*/{}, operands, successors);

  // All operands not forwarded to any successor. This set can be non-contiguous
  // in the presence of multiple successors.
  BitVector unaccounted(op->getNumOperands(), true);

  for (RegionSuccessor &successor : successors) {
    Region *region = successor.getSuccessor();
    // A null region designates the parent op itself as the successor.
    OperandRange operands =
        region ? branch.getSuccessorEntryOperands(region->getRegionNumber())
               : branch.getSuccessorEntryOperands({});
    MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
    ValueRange inputs = successor.getSuccessorInputs();
    for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
      meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input));
      unaccounted.reset(operand.getOperandNumber());
    }
  }
  // All operands not forwarded to regions are typically parameters of the
  // branch operation itself (for example the boolean for if/else).
  for (int index : unaccounted.set_bits()) {
    visitBranchOperand(op->getOpOperand(index));
  }
}
533 
534 const AbstractSparseLattice *
535 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
536                                                              Value value) {
537   AbstractSparseLattice *state = getLatticeElement(value);
538   addDependency(state, point);
539   return state;
540 }
541 
542 void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
543     ArrayRef<AbstractSparseLattice *> lattices) {
544   for (AbstractSparseLattice *lattice : lattices)
545     setToExitState(lattice);
546 }
547 
548 void AbstractSparseBackwardDataFlowAnalysis::meet(
549     AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
550   propagateIfChanged(lhs, lhs->meet(rhs));
551 }
552