xref: /llvm-project/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (revision b5c5c2b26fd4bd0d0d237aaf77a01ca528810707)
1c095afcbSMogball //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2c095afcbSMogball //
3c095afcbSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c095afcbSMogball // See https://llvm.org/LICENSE.txt for license information.
5c095afcbSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c095afcbSMogball //
7c095afcbSMogball //===----------------------------------------------------------------------===//
8c095afcbSMogball 
9c095afcbSMogball #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
109432fbfeSMogball #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
116a666737SZhixun Tan #include "mlir/Analysis/DataFlowFramework.h"
12ebc2c4bdSMehdi Amini #include "mlir/IR/Attributes.h"
13ebc2c4bdSMehdi Amini #include "mlir/IR/Operation.h"
14ebc2c4bdSMehdi Amini #include "mlir/IR/Region.h"
15ebc2c4bdSMehdi Amini #include "mlir/IR/SymbolTable.h"
16ebc2c4bdSMehdi Amini #include "mlir/IR/Value.h"
17ebc2c4bdSMehdi Amini #include "mlir/IR/ValueRange.h"
189432fbfeSMogball #include "mlir/Interfaces/CallInterfaces.h"
19ebc2c4bdSMehdi Amini #include "mlir/Interfaces/ControlFlowInterfaces.h"
20ebc2c4bdSMehdi Amini #include "mlir/Support/LLVM.h"
21ebc2c4bdSMehdi Amini #include "llvm/ADT/STLExtras.h"
22ebc2c4bdSMehdi Amini #include "llvm/Support/Casting.h"
23ebc2c4bdSMehdi Amini #include <cassert>
24ebc2c4bdSMehdi Amini #include <optional>
25c095afcbSMogball 
26c095afcbSMogball using namespace mlir;
27c095afcbSMogball using namespace mlir::dataflow;
28c095afcbSMogball 
29c095afcbSMogball //===----------------------------------------------------------------------===//
30c095afcbSMogball // AbstractSparseLattice
31c095afcbSMogball //===----------------------------------------------------------------------===//
32c095afcbSMogball 
33c095afcbSMogball void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
346a666737SZhixun Tan   AnalysisState::onUpdate(solver);
356a666737SZhixun Tan 
36c095afcbSMogball   // Push all users of the value to the queue.
37*b5c5c2b2SKazu Hirata   for (Operation *user : cast<Value>(anchor).getUsers())
38c095afcbSMogball     for (DataFlowAnalysis *analysis : useDefSubscribers)
394b3f251bSdonald chen       solver->enqueue({solver->getProgramPointAfter(user), analysis});
40c095afcbSMogball }
419432fbfeSMogball 
429432fbfeSMogball //===----------------------------------------------------------------------===//
43b2b7efb9SAlex Zinenko // AbstractSparseForwardDataFlowAnalysis
449432fbfeSMogball //===----------------------------------------------------------------------===//
459432fbfeSMogball 
46b2b7efb9SAlex Zinenko AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
479432fbfeSMogball     DataFlowSolver &solver)
489432fbfeSMogball     : DataFlowAnalysis(solver) {
49b6603e1bSdonald chen   registerAnchorKind<CFGEdge>();
509432fbfeSMogball }
519432fbfeSMogball 
52b2b7efb9SAlex Zinenko LogicalResult
53b2b7efb9SAlex Zinenko AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
549432fbfeSMogball   // Mark the entry block arguments as having reached their pessimistic
559432fbfeSMogball   // fixpoints.
569432fbfeSMogball   for (Region &region : top->getRegions()) {
579432fbfeSMogball     if (region.empty())
589432fbfeSMogball       continue;
599432fbfeSMogball     for (Value argument : region.front().getArguments())
604e98d611SMatthias Kramm       setToEntryState(getLatticeElement(argument));
619432fbfeSMogball   }
629432fbfeSMogball 
639432fbfeSMogball   return initializeRecursively(top);
649432fbfeSMogball }
659432fbfeSMogball 
669432fbfeSMogball LogicalResult
67b2b7efb9SAlex Zinenko AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
689432fbfeSMogball   // Initialize the analysis by visiting every owner of an SSA value (all
699432fbfeSMogball   // operations and blocks).
7015e915a4SIvan Butygin   if (failed(visitOperation(op)))
7115e915a4SIvan Butygin     return failure();
7215e915a4SIvan Butygin 
739432fbfeSMogball   for (Region &region : op->getRegions()) {
749432fbfeSMogball     for (Block &block : region) {
754b3f251bSdonald chen       getOrCreate<Executable>(getProgramPointBefore(&block))
764b3f251bSdonald chen           ->blockContentSubscribe(this);
779432fbfeSMogball       visitBlock(&block);
789432fbfeSMogball       for (Operation &op : block)
799432fbfeSMogball         if (failed(initializeRecursively(&op)))
809432fbfeSMogball           return failure();
819432fbfeSMogball     }
829432fbfeSMogball   }
839432fbfeSMogball 
849432fbfeSMogball   return success();
859432fbfeSMogball }
869432fbfeSMogball 
874b3f251bSdonald chen LogicalResult
884b3f251bSdonald chen AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
894b3f251bSdonald chen   if (!point->isBlockStart())
904b3f251bSdonald chen     return visitOperation(point->getPrevOp());
914b3f251bSdonald chen   visitBlock(point->getBlock());
929432fbfeSMogball   return success();
939432fbfeSMogball }
949432fbfeSMogball 
9515e915a4SIvan Butygin LogicalResult
9615e915a4SIvan Butygin AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
979432fbfeSMogball   // Exit early on operations with no results.
989432fbfeSMogball   if (op->getNumResults() == 0)
9915e915a4SIvan Butygin     return success();
1009432fbfeSMogball 
1019432fbfeSMogball   // If the containing block is not executable, bail out.
1024b3f251bSdonald chen   if (op->getBlock() != nullptr &&
1034b3f251bSdonald chen       !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
10415e915a4SIvan Butygin     return success();
1059432fbfeSMogball 
1069432fbfeSMogball   // Get the result lattices.
1079432fbfeSMogball   SmallVector<AbstractSparseLattice *> resultLattices;
1089432fbfeSMogball   resultLattices.reserve(op->getNumResults());
1099432fbfeSMogball   for (Value result : op->getResults()) {
1109432fbfeSMogball     AbstractSparseLattice *resultLattice = getLatticeElement(result);
1119432fbfeSMogball     resultLattices.push_back(resultLattice);
1129432fbfeSMogball   }
1139432fbfeSMogball 
1149432fbfeSMogball   // The results of a region branch operation are determined by control-flow.
1159432fbfeSMogball   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
1164b3f251bSdonald chen     visitRegionSuccessors(getProgramPointAfter(branch), branch,
1174dd744acSMarkus Böck                           /*successor=*/RegionBranchPoint::parent(),
1181a36588eSKazu Hirata                           resultLattices);
11915e915a4SIvan Butygin     return success();
1209432fbfeSMogball   }
1219432fbfeSMogball 
12232a4e3fcSOleksandr "Alex" Zinenko   // Grab the lattice elements of the operands.
12332a4e3fcSOleksandr "Alex" Zinenko   SmallVector<const AbstractSparseLattice *> operandLattices;
12432a4e3fcSOleksandr "Alex" Zinenko   operandLattices.reserve(op->getNumOperands());
12532a4e3fcSOleksandr "Alex" Zinenko   for (Value operand : op->getOperands()) {
12632a4e3fcSOleksandr "Alex" Zinenko     AbstractSparseLattice *operandLattice = getLatticeElement(operand);
12732a4e3fcSOleksandr "Alex" Zinenko     operandLattice->useDefSubscribe(this);
12832a4e3fcSOleksandr "Alex" Zinenko     operandLattices.push_back(operandLattice);
12932a4e3fcSOleksandr "Alex" Zinenko   }
13032a4e3fcSOleksandr "Alex" Zinenko 
1319432fbfeSMogball   if (auto call = dyn_cast<CallOpInterface>(op)) {
13232a4e3fcSOleksandr "Alex" Zinenko     // If the call operation is to an external function, attempt to infer the
13332a4e3fcSOleksandr "Alex" Zinenko     // results from the call arguments.
13432a4e3fcSOleksandr "Alex" Zinenko     auto callable =
13532a4e3fcSOleksandr "Alex" Zinenko         dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
13632a4e3fcSOleksandr "Alex" Zinenko     if (!getSolverConfig().isInterprocedural() ||
13732a4e3fcSOleksandr "Alex" Zinenko         (callable && !callable.getCallableRegion())) {
13815e915a4SIvan Butygin       visitExternalCallImpl(call, operandLattices, resultLattices);
13915e915a4SIvan Butygin       return success();
14032a4e3fcSOleksandr "Alex" Zinenko     }
14132a4e3fcSOleksandr "Alex" Zinenko 
14232a4e3fcSOleksandr "Alex" Zinenko     // Otherwise, the results of a call operation are determined by the
14332a4e3fcSOleksandr "Alex" Zinenko     // callgraph.
1444b3f251bSdonald chen     const auto *predecessors = getOrCreateFor<PredecessorState>(
1454b3f251bSdonald chen         getProgramPointAfter(op), getProgramPointAfter(call));
1469432fbfeSMogball     // If not all return sites are known, then conservatively assume we can't
1479432fbfeSMogball     // reason about the data-flow.
14815e915a4SIvan Butygin     if (!predecessors->allPredecessorsKnown()) {
14915e915a4SIvan Butygin       setAllToEntryStates(resultLattices);
15015e915a4SIvan Butygin       return success();
15115e915a4SIvan Butygin     }
1529432fbfeSMogball     for (Operation *predecessor : predecessors->getKnownPredecessors())
15315e915a4SIvan Butygin       for (auto &&[operand, resLattice] :
15415e915a4SIvan Butygin            llvm::zip(predecessor->getOperands(), resultLattices))
1554b3f251bSdonald chen         join(resLattice,
1564b3f251bSdonald chen              *getLatticeElementFor(getProgramPointAfter(op), operand));
15715e915a4SIvan Butygin     return success();
1589432fbfeSMogball   }
1599432fbfeSMogball 
1609432fbfeSMogball   // Invoke the operation transfer function.
16115e915a4SIvan Butygin   return visitOperationImpl(op, operandLattices, resultLattices);
1629432fbfeSMogball }
1639432fbfeSMogball 
164b2b7efb9SAlex Zinenko void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
1659432fbfeSMogball   // Exit early on blocks with no arguments.
1669432fbfeSMogball   if (block->getNumArguments() == 0)
1679432fbfeSMogball     return;
1689432fbfeSMogball 
1699432fbfeSMogball   // If the block is not executable, bail out.
1704b3f251bSdonald chen   if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive())
1719432fbfeSMogball     return;
1729432fbfeSMogball 
1739432fbfeSMogball   // Get the argument lattices.
1749432fbfeSMogball   SmallVector<AbstractSparseLattice *> argLattices;
1759432fbfeSMogball   argLattices.reserve(block->getNumArguments());
1769432fbfeSMogball   for (BlockArgument argument : block->getArguments()) {
1779432fbfeSMogball     AbstractSparseLattice *argLattice = getLatticeElement(argument);
1789432fbfeSMogball     argLattices.push_back(argLattice);
1799432fbfeSMogball   }
1809432fbfeSMogball 
1819432fbfeSMogball   // The argument lattices of entry blocks are set by region control-flow or the
1829432fbfeSMogball   // callgraph.
1839432fbfeSMogball   if (block->isEntryBlock()) {
1849432fbfeSMogball     // Check if this block is the entry block of a callable region.
1859432fbfeSMogball     auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
1869432fbfeSMogball     if (callable && callable.getCallableRegion() == block->getParent()) {
1874b3f251bSdonald chen       const auto *callsites = getOrCreateFor<PredecessorState>(
1884b3f251bSdonald chen           getProgramPointBefore(block), getProgramPointAfter(callable));
1899432fbfeSMogball       // If not all callsites are known, conservatively mark all lattices as
1909432fbfeSMogball       // having reached their pessimistic fixpoints.
19132a4e3fcSOleksandr "Alex" Zinenko       if (!callsites->allPredecessorsKnown() ||
19232a4e3fcSOleksandr "Alex" Zinenko           !getSolverConfig().isInterprocedural()) {
193de0ebc52SZhixun Tan         return setAllToEntryStates(argLattices);
19432a4e3fcSOleksandr "Alex" Zinenko       }
1959432fbfeSMogball       for (Operation *callsite : callsites->getKnownPredecessors()) {
1969432fbfeSMogball         auto call = cast<CallOpInterface>(callsite);
1979432fbfeSMogball         for (auto it : llvm::zip(call.getArgOperands(), argLattices))
1984b3f251bSdonald chen           join(std::get<1>(it),
1994b3f251bSdonald chen                *getLatticeElementFor(getProgramPointBefore(block),
2004b3f251bSdonald chen                                      std::get<0>(it)));
2019432fbfeSMogball       }
2029432fbfeSMogball       return;
2039432fbfeSMogball     }
2049432fbfeSMogball 
2059432fbfeSMogball     // Check if the lattices can be determined from region control flow.
2069432fbfeSMogball     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
2074b3f251bSdonald chen       return visitRegionSuccessors(getProgramPointBefore(block), branch,
2084b3f251bSdonald chen                                    block->getParent(), argLattices);
2099432fbfeSMogball     }
2109432fbfeSMogball 
2119432fbfeSMogball     // Otherwise, we can't reason about the data-flow.
212ab701975SMogball     return visitNonControlFlowArgumentsImpl(block->getParentOp(),
213ab701975SMogball                                             RegionSuccessor(block->getParent()),
214ab701975SMogball                                             argLattices, /*firstIndex=*/0);
2159432fbfeSMogball   }
2169432fbfeSMogball 
2179432fbfeSMogball   // Iterate over the predecessors of the non-entry block.
2189432fbfeSMogball   for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
2199432fbfeSMogball        it != e; ++it) {
2209432fbfeSMogball     Block *predecessor = *it;
2219432fbfeSMogball 
2229432fbfeSMogball     // If the edge from the predecessor block to the current block is not live,
2239432fbfeSMogball     // bail out.
2249432fbfeSMogball     auto *edgeExecutable =
225b6603e1bSdonald chen         getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
2269432fbfeSMogball     edgeExecutable->blockContentSubscribe(this);
2279432fbfeSMogball     if (!edgeExecutable->isLive())
2289432fbfeSMogball       continue;
2299432fbfeSMogball 
2309432fbfeSMogball     // Check if we can reason about the data-flow from the predecessor.
2319432fbfeSMogball     if (auto branch =
2329432fbfeSMogball             dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
2339432fbfeSMogball       SuccessorOperands operands =
2349432fbfeSMogball           branch.getSuccessorOperands(it.getSuccessorIndex());
2358c258fdaSJakub Kuderski       for (auto [idx, lattice] : llvm::enumerate(argLattices)) {
2368c258fdaSJakub Kuderski         if (Value operand = operands[idx]) {
2374b3f251bSdonald chen           join(lattice,
2384b3f251bSdonald chen                *getLatticeElementFor(getProgramPointBefore(block), operand));
2399432fbfeSMogball         } else {
240de0ebc52SZhixun Tan           // Conservatively consider internally produced arguments as entry
241de0ebc52SZhixun Tan           // points.
2428c258fdaSJakub Kuderski           setAllToEntryStates(lattice);
2439432fbfeSMogball         }
2449432fbfeSMogball       }
2459432fbfeSMogball     } else {
246de0ebc52SZhixun Tan       return setAllToEntryStates(argLattices);
2479432fbfeSMogball     }
2489432fbfeSMogball   }
2499432fbfeSMogball }
2509432fbfeSMogball 
251b2b7efb9SAlex Zinenko void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
2524b3f251bSdonald chen     ProgramPoint *point, RegionBranchOpInterface branch,
2534dd744acSMarkus Böck     RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
2549432fbfeSMogball   const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
2559432fbfeSMogball   assert(predecessors->allPredecessorsKnown() &&
2569432fbfeSMogball          "unexpected unresolved region successors");
2579432fbfeSMogball 
2589432fbfeSMogball   for (Operation *op : predecessors->getKnownPredecessors()) {
2599432fbfeSMogball     // Get the incoming successor operands.
26022426110SRamkumar Ramachandra     std::optional<OperandRange> operands;
2619432fbfeSMogball 
2629432fbfeSMogball     // Check if the predecessor is the parent op.
2639432fbfeSMogball     if (op == branch) {
2644dd744acSMarkus Böck       operands = branch.getEntrySuccessorOperands(successor);
2659432fbfeSMogball       // Otherwise, try to deduce the operands from a region return-like op.
26610ae8ae8SMarkus Böck     } else if (auto regionTerminator =
26710ae8ae8SMarkus Böck                    dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
2684dd744acSMarkus Böck       operands = regionTerminator.getSuccessorOperands(successor);
2699432fbfeSMogball     }
2709432fbfeSMogball 
2719432fbfeSMogball     if (!operands) {
2729432fbfeSMogball       // We can't reason about the data-flow.
273de0ebc52SZhixun Tan       return setAllToEntryStates(lattices);
2749432fbfeSMogball     }
2759432fbfeSMogball 
2769432fbfeSMogball     ValueRange inputs = predecessors->getSuccessorInputs(op);
2779432fbfeSMogball     assert(inputs.size() == operands->size() &&
2789432fbfeSMogball            "expected the same number of successor inputs as operands");
2799432fbfeSMogball 
2809432fbfeSMogball     unsigned firstIndex = 0;
2819432fbfeSMogball     if (inputs.size() != lattices.size()) {
2824b3f251bSdonald chen       if (!point->isBlockStart()) {
283ab701975SMogball         if (!inputs.empty())
2845550c821STres Popp           firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
285ab701975SMogball         visitNonControlFlowArgumentsImpl(
286ab701975SMogball             branch,
287ab701975SMogball             RegionSuccessor(
288ab701975SMogball                 branch->getResults().slice(firstIndex, inputs.size())),
289ab701975SMogball             lattices, firstIndex);
290ab701975SMogball       } else {
291ab701975SMogball         if (!inputs.empty())
2925550c821STres Popp           firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
2934b3f251bSdonald chen         Region *region = point->getBlock()->getParent();
294ab701975SMogball         visitNonControlFlowArgumentsImpl(
295ab701975SMogball             branch,
296ab701975SMogball             RegionSuccessor(region, region->getArguments().slice(
297ab701975SMogball                                         firstIndex, inputs.size())),
298ab701975SMogball             lattices, firstIndex);
299ab701975SMogball       }
3009432fbfeSMogball     }
3019432fbfeSMogball 
3029432fbfeSMogball     for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
3039432fbfeSMogball       join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
3049432fbfeSMogball   }
3059432fbfeSMogball }
3069432fbfeSMogball 
3079432fbfeSMogball const AbstractSparseLattice *
3084b3f251bSdonald chen AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint *point,
3099432fbfeSMogball                                                             Value value) {
3109432fbfeSMogball   AbstractSparseLattice *state = getLatticeElement(value);
3119432fbfeSMogball   addDependency(state, point);
3129432fbfeSMogball   return state;
3139432fbfeSMogball }
3149432fbfeSMogball 
315b2b7efb9SAlex Zinenko void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
3169432fbfeSMogball     ArrayRef<AbstractSparseLattice *> lattices) {
3179432fbfeSMogball   for (AbstractSparseLattice *lattice : lattices)
318de0ebc52SZhixun Tan     setToEntryState(lattice);
3199432fbfeSMogball }
3209432fbfeSMogball 
321b2b7efb9SAlex Zinenko void AbstractSparseForwardDataFlowAnalysis::join(
322b2b7efb9SAlex Zinenko     AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
3239432fbfeSMogball   propagateIfChanged(lhs, lhs->join(rhs));
3249432fbfeSMogball }
3254e98d611SMatthias Kramm 
3264e98d611SMatthias Kramm //===----------------------------------------------------------------------===//
3274e98d611SMatthias Kramm // AbstractSparseBackwardDataFlowAnalysis
3284e98d611SMatthias Kramm //===----------------------------------------------------------------------===//
3294e98d611SMatthias Kramm 
3304e98d611SMatthias Kramm AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
3314e98d611SMatthias Kramm     DataFlowSolver &solver, SymbolTableCollection &symbolTable)
3324e98d611SMatthias Kramm     : DataFlowAnalysis(solver), symbolTable(symbolTable) {
333b6603e1bSdonald chen   registerAnchorKind<CFGEdge>();
3344e98d611SMatthias Kramm }
3354e98d611SMatthias Kramm 
3364e98d611SMatthias Kramm LogicalResult
3374e98d611SMatthias Kramm AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
3384e98d611SMatthias Kramm   return initializeRecursively(top);
3394e98d611SMatthias Kramm }
3404e98d611SMatthias Kramm 
3414e98d611SMatthias Kramm LogicalResult
3424e98d611SMatthias Kramm AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
34315e915a4SIvan Butygin   if (failed(visitOperation(op)))
34415e915a4SIvan Butygin     return failure();
34515e915a4SIvan Butygin 
3464e98d611SMatthias Kramm   for (Region &region : op->getRegions()) {
3474e98d611SMatthias Kramm     for (Block &block : region) {
3484b3f251bSdonald chen       getOrCreate<Executable>(getProgramPointBefore(&block))
3494b3f251bSdonald chen           ->blockContentSubscribe(this);
3504e98d611SMatthias Kramm       // Initialize ops in reverse order, so we can do as much initial
3514e98d611SMatthias Kramm       // propagation as possible without having to go through the
3524e98d611SMatthias Kramm       // solver queue.
3534e98d611SMatthias Kramm       for (auto it = block.rbegin(); it != block.rend(); it++)
3544e98d611SMatthias Kramm         if (failed(initializeRecursively(&*it)))
3554e98d611SMatthias Kramm           return failure();
3564e98d611SMatthias Kramm     }
3574e98d611SMatthias Kramm   }
3584e98d611SMatthias Kramm   return success();
3594e98d611SMatthias Kramm }
3604e98d611SMatthias Kramm 
3614e98d611SMatthias Kramm LogicalResult
3624b3f251bSdonald chen AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
3634e98d611SMatthias Kramm   // For backward dataflow, we don't have to do any work for the blocks
3644e98d611SMatthias Kramm   // themselves. CFG edges between blocks are processed by the BranchOp
3654e98d611SMatthias Kramm   // logic in `visitOperation`, and entry blocks for functions are tied
3664e98d611SMatthias Kramm   // to the CallOp arguments by visitOperation.
3674b3f251bSdonald chen   if (point->isBlockStart())
3684e98d611SMatthias Kramm     return success();
3694b3f251bSdonald chen   return visitOperation(point->getPrevOp());
3704e98d611SMatthias Kramm }
3714e98d611SMatthias Kramm 
3724e98d611SMatthias Kramm SmallVector<AbstractSparseLattice *>
3734e98d611SMatthias Kramm AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
3744e98d611SMatthias Kramm   SmallVector<AbstractSparseLattice *> resultLattices;
3754e98d611SMatthias Kramm   resultLattices.reserve(values.size());
3764e98d611SMatthias Kramm   for (Value result : values) {
3774e98d611SMatthias Kramm     AbstractSparseLattice *resultLattice = getLatticeElement(result);
3784e98d611SMatthias Kramm     resultLattices.push_back(resultLattice);
3794e98d611SMatthias Kramm   }
3804e98d611SMatthias Kramm   return resultLattices;
3814e98d611SMatthias Kramm }
3824e98d611SMatthias Kramm 
3834e98d611SMatthias Kramm SmallVector<const AbstractSparseLattice *>
3844e98d611SMatthias Kramm AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
3854b3f251bSdonald chen     ProgramPoint *point, ValueRange values) {
3864e98d611SMatthias Kramm   SmallVector<const AbstractSparseLattice *> resultLattices;
3874e98d611SMatthias Kramm   resultLattices.reserve(values.size());
3884e98d611SMatthias Kramm   for (Value result : values) {
3894e98d611SMatthias Kramm     const AbstractSparseLattice *resultLattice =
3904e98d611SMatthias Kramm         getLatticeElementFor(point, result);
3914e98d611SMatthias Kramm     resultLattices.push_back(resultLattice);
3924e98d611SMatthias Kramm   }
3934e98d611SMatthias Kramm   return resultLattices;
3944e98d611SMatthias Kramm }
3954e98d611SMatthias Kramm 
3964e98d611SMatthias Kramm static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
3974e98d611SMatthias Kramm   return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
3984e98d611SMatthias Kramm }
3994e98d611SMatthias Kramm 
40015e915a4SIvan Butygin LogicalResult
40115e915a4SIvan Butygin AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
4024e98d611SMatthias Kramm   // If we're in a dead block, bail out.
4034b3f251bSdonald chen   if (op->getBlock() != nullptr &&
4044b3f251bSdonald chen       !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
40515e915a4SIvan Butygin     return success();
4064e98d611SMatthias Kramm 
4074e98d611SMatthias Kramm   SmallVector<AbstractSparseLattice *> operandLattices =
4084e98d611SMatthias Kramm       getLatticeElements(op->getOperands());
4094e98d611SMatthias Kramm   SmallVector<const AbstractSparseLattice *> resultLattices =
4104b3f251bSdonald chen       getLatticeElementsFor(getProgramPointAfter(op), op->getResults());
4114e98d611SMatthias Kramm 
4124e98d611SMatthias Kramm   // Block arguments of region branch operations flow back into the operands
4134e98d611SMatthias Kramm   // of the parent op
4144e98d611SMatthias Kramm   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
4154e98d611SMatthias Kramm     visitRegionSuccessors(branch, operandLattices);
41615e915a4SIvan Butygin     return success();
4174e98d611SMatthias Kramm   }
4184e98d611SMatthias Kramm 
4194e98d611SMatthias Kramm   if (auto branch = dyn_cast<BranchOpInterface>(op)) {
4204e98d611SMatthias Kramm     // Block arguments of successor blocks flow back into our operands.
4214e98d611SMatthias Kramm 
4224e98d611SMatthias Kramm     // We remember all operands not forwarded to any block in a BitVector.
4234e98d611SMatthias Kramm     // We can't just cut out a range here, since the non-forwarded ops might
4244e98d611SMatthias Kramm     // be non-contiguous (if there's more than one successor).
4254e98d611SMatthias Kramm     BitVector unaccounted(op->getNumOperands(), true);
4264e98d611SMatthias Kramm 
4274e98d611SMatthias Kramm     for (auto [index, block] : llvm::enumerate(op->getSuccessors())) {
4284e98d611SMatthias Kramm       SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
4294e98d611SMatthias Kramm       OperandRange forwarded = successorOperands.getForwardedOperands();
4300fe37a75SAdrian Kuegel       if (!forwarded.empty()) {
4314e98d611SMatthias Kramm         MutableArrayRef<OpOperand> operands = op->getOpOperands().slice(
4324e98d611SMatthias Kramm             forwarded.getBeginOperandIndex(), forwarded.size());
4334e98d611SMatthias Kramm         for (OpOperand &operand : operands) {
4344e98d611SMatthias Kramm           unaccounted.reset(operand.getOperandNumber());
43522426110SRamkumar Ramachandra           if (std::optional<BlockArgument> blockArg =
4364e98d611SMatthias Kramm                   detail::getBranchSuccessorArgument(
4374e98d611SMatthias Kramm                       successorOperands, operand.getOperandNumber(), block)) {
4384e98d611SMatthias Kramm             meet(getLatticeElement(operand.get()),
4394b3f251bSdonald chen                  *getLatticeElementFor(getProgramPointAfter(op), *blockArg));
4404e98d611SMatthias Kramm           }
4414e98d611SMatthias Kramm         }
4424e98d611SMatthias Kramm       }
4434e98d611SMatthias Kramm     }
4444e98d611SMatthias Kramm     // Operands not forwarded to successor blocks are typically parameters
4454e98d611SMatthias Kramm     // of the branch operation itself (for example the boolean for if/else).
4464e98d611SMatthias Kramm     for (int index : unaccounted.set_bits()) {
4474e98d611SMatthias Kramm       OpOperand &operand = op->getOpOperand(index);
4484e98d611SMatthias Kramm       visitBranchOperand(operand);
4494e98d611SMatthias Kramm     }
45015e915a4SIvan Butygin     return success();
4514e98d611SMatthias Kramm   }
4524e98d611SMatthias Kramm 
453232f8eadSSrishti Srivastava   // For function calls, connect the arguments of the entry blocks to the
454232f8eadSSrishti Srivastava   // operands of the call op that are forwarded to these arguments.
4554e98d611SMatthias Kramm   if (auto call = dyn_cast<CallOpInterface>(op)) {
456d1cad229SHenrich Lauko     Operation *callableOp = call.resolveCallableInTable(&symbolTable);
4574e98d611SMatthias Kramm     if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
458232f8eadSSrishti Srivastava       // Not all operands of a call op forward to arguments. Such operands are
459232f8eadSSrishti Srivastava       // stored in `unaccounted`.
460232f8eadSSrishti Srivastava       BitVector unaccounted(op->getNumOperands(), true);
461232f8eadSSrishti Srivastava 
46232a4e3fcSOleksandr "Alex" Zinenko       // If the call invokes an external function (or a function treated as
46332a4e3fcSOleksandr "Alex" Zinenko       // external due to config), defer to the corresponding extension hook.
46432a4e3fcSOleksandr "Alex" Zinenko       // By default, it just does `visitCallOperand` for all operands.
465232f8eadSSrishti Srivastava       OperandRange argOperands = call.getArgOperands();
466232f8eadSSrishti Srivastava       MutableArrayRef<OpOperand> argOpOperands =
467232f8eadSSrishti Srivastava           operandsToOpOperands(argOperands);
4684e98d611SMatthias Kramm       Region *region = callable.getCallableRegion();
46915e915a4SIvan Butygin       if (!region || region->empty() ||
47015e915a4SIvan Butygin           !getSolverConfig().isInterprocedural()) {
47115e915a4SIvan Butygin         visitExternalCallImpl(call, operandLattices, resultLattices);
47215e915a4SIvan Butygin         return success();
47315e915a4SIvan Butygin       }
47432a4e3fcSOleksandr "Alex" Zinenko 
47532a4e3fcSOleksandr "Alex" Zinenko       // Otherwise, propagate information from the entry point of the function
47632a4e3fcSOleksandr "Alex" Zinenko       // back to operands whenever possible.
4774e98d611SMatthias Kramm       Block &block = region->front();
478232f8eadSSrishti Srivastava       for (auto [blockArg, argOpOperand] :
479232f8eadSSrishti Srivastava            llvm::zip(block.getArguments(), argOpOperands)) {
480232f8eadSSrishti Srivastava         meet(getLatticeElement(argOpOperand.get()),
4814b3f251bSdonald chen              *getLatticeElementFor(getProgramPointAfter(op), blockArg));
482232f8eadSSrishti Srivastava         unaccounted.reset(argOpOperand.getOperandNumber());
4834e98d611SMatthias Kramm       }
48432a4e3fcSOleksandr "Alex" Zinenko 
485232f8eadSSrishti Srivastava       // Handle the operands of the call op that aren't forwarded to any
486232f8eadSSrishti Srivastava       // arguments.
487232f8eadSSrishti Srivastava       for (int index : unaccounted.set_bits()) {
488232f8eadSSrishti Srivastava         OpOperand &opOperand = op->getOpOperand(index);
489232f8eadSSrishti Srivastava         visitCallOperand(opOperand);
490232f8eadSSrishti Srivastava       }
49115e915a4SIvan Butygin       return success();
4924e98d611SMatthias Kramm     }
4934e98d611SMatthias Kramm   }
4944e98d611SMatthias Kramm 
495a9ab845cSSrishti Srivastava   // When the region of an op implementing `RegionBranchOpInterface` has a
496a9ab845cSSrishti Srivastava   // terminator implementing `RegionBranchTerminatorOpInterface` or a
497a9ab845cSSrishti Srivastava   // return-like terminator, the region's successors' arguments flow back into
498a9ab845cSSrishti Srivastava   // the "successor operands" of this terminator.
499a9ab845cSSrishti Srivastava   //
500a9ab845cSSrishti Srivastava   // A successor operand with respect to an op implementing
501a9ab845cSSrishti Srivastava   // `RegionBranchOpInterface` is an operand that is forwarded to a region
502a9ab845cSSrishti Srivastava   // successor's input. There are two types of successor operands: the operands
503a9ab845cSSrishti Srivastava   // of this op itself and the operands of the terminators of the regions of
504a9ab845cSSrishti Srivastava   // this op.
50510ae8ae8SMarkus Böck   if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
5064e98d611SMatthias Kramm     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
50710ae8ae8SMarkus Böck       visitRegionSuccessorsFromTerminator(terminator, branch);
50815e915a4SIvan Butygin       return success();
5094e98d611SMatthias Kramm     }
5104e98d611SMatthias Kramm   }
5114e98d611SMatthias Kramm 
5124e98d611SMatthias Kramm   if (op->hasTrait<OpTrait::ReturnLike>()) {
5134e98d611SMatthias Kramm     // Going backwards, the operands of the return are derived from the
5144e98d611SMatthias Kramm     // results of all CallOps calling this CallableOp.
5154e98d611SMatthias Kramm     if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
5164b3f251bSdonald chen       const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
5174b3f251bSdonald chen           getProgramPointAfter(op), getProgramPointAfter(callable));
5184e98d611SMatthias Kramm       if (callsites->allPredecessorsKnown()) {
5194e98d611SMatthias Kramm         for (Operation *call : callsites->getKnownPredecessors()) {
5204e98d611SMatthias Kramm           SmallVector<const AbstractSparseLattice *> callResultLattices =
5214b3f251bSdonald chen               getLatticeElementsFor(getProgramPointAfter(op),
5224b3f251bSdonald chen                                     call->getResults());
5234e98d611SMatthias Kramm           for (auto [op, result] :
5244e98d611SMatthias Kramm                llvm::zip(operandLattices, callResultLattices))
5254e98d611SMatthias Kramm             meet(op, *result);
5264e98d611SMatthias Kramm         }
5274e98d611SMatthias Kramm       } else {
5284e98d611SMatthias Kramm         // If we don't know all the callers, we can't know where the
5294e98d611SMatthias Kramm         // returned values go. Note that, in particular, this will trigger
5304e98d611SMatthias Kramm         // for the return ops of any public functions.
5314e98d611SMatthias Kramm         setAllToExitStates(operandLattices);
5324e98d611SMatthias Kramm       }
53315e915a4SIvan Butygin       return success();
5344e98d611SMatthias Kramm     }
5354e98d611SMatthias Kramm   }
5364e98d611SMatthias Kramm 
53715e915a4SIvan Butygin   return visitOperationImpl(op, operandLattices, resultLattices);
5384e98d611SMatthias Kramm }
5394e98d611SMatthias Kramm 
5404e98d611SMatthias Kramm void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
5414e98d611SMatthias Kramm     RegionBranchOpInterface branch,
5424e98d611SMatthias Kramm     ArrayRef<AbstractSparseLattice *> operandLattices) {
5434e98d611SMatthias Kramm   Operation *op = branch.getOperation();
5444e98d611SMatthias Kramm   SmallVector<RegionSuccessor> successors;
5454e98d611SMatthias Kramm   SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
546138df298SMarkus Böck   branch.getEntrySuccessorRegions(operands, successors);
5474e98d611SMatthias Kramm 
5484e98d611SMatthias Kramm   // All operands not forwarded to any successor. This set can be non-contiguous
5494e98d611SMatthias Kramm   // in the presence of multiple successors.
5504e98d611SMatthias Kramm   BitVector unaccounted(op->getNumOperands(), true);
5514e98d611SMatthias Kramm 
5524e98d611SMatthias Kramm   for (RegionSuccessor &successor : successors) {
5534dd744acSMarkus Böck     OperandRange operands = branch.getEntrySuccessorOperands(successor);
5544e98d611SMatthias Kramm     MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
5554e98d611SMatthias Kramm     ValueRange inputs = successor.getSuccessorInputs();
5564e98d611SMatthias Kramm     for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
5574b3f251bSdonald chen       meet(getLatticeElement(operand.get()),
5584b3f251bSdonald chen            *getLatticeElementFor(getProgramPointAfter(op), input));
5594e98d611SMatthias Kramm       unaccounted.reset(operand.getOperandNumber());
5604e98d611SMatthias Kramm     }
5614e98d611SMatthias Kramm   }
5624e98d611SMatthias Kramm   // All operands not forwarded to regions are typically parameters of the
5634e98d611SMatthias Kramm   // branch operation itself (for example the boolean for if/else).
5644e98d611SMatthias Kramm   for (int index : unaccounted.set_bits()) {
5654e98d611SMatthias Kramm     visitBranchOperand(op->getOpOperand(index));
5664e98d611SMatthias Kramm   }
5674e98d611SMatthias Kramm }
5684e98d611SMatthias Kramm 
569a9ab845cSSrishti Srivastava void AbstractSparseBackwardDataFlowAnalysis::
57010ae8ae8SMarkus Böck     visitRegionSuccessorsFromTerminator(
57110ae8ae8SMarkus Böck         RegionBranchTerminatorOpInterface terminator,
572a9ab845cSSrishti Srivastava         RegionBranchOpInterface branch) {
57310ae8ae8SMarkus Böck   assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
57410ae8ae8SMarkus Böck          "expected a `RegionBranchTerminatorOpInterface` op");
575a9ab845cSSrishti Srivastava   assert(terminator->getParentOp() == branch.getOperation() &&
576a9ab845cSSrishti Srivastava          "expected `branch` to be the parent op of `terminator`");
577a9ab845cSSrishti Srivastava 
578a9ab845cSSrishti Srivastava   SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
579a9ab845cSSrishti Srivastava                                            nullptr);
580a9ab845cSSrishti Srivastava   SmallVector<RegionSuccessor> successors;
581138df298SMarkus Böck   terminator.getSuccessorRegions(operandAttributes, successors);
582a9ab845cSSrishti Srivastava   // All operands not forwarded to any successor. This set can be
583a9ab845cSSrishti Srivastava   // non-contiguous in the presence of multiple successors.
584a9ab845cSSrishti Srivastava   BitVector unaccounted(terminator->getNumOperands(), true);
585a9ab845cSSrishti Srivastava 
586a9ab845cSSrishti Srivastava   for (const RegionSuccessor &successor : successors) {
587a9ab845cSSrishti Srivastava     ValueRange inputs = successor.getSuccessorInputs();
5884dd744acSMarkus Böck     OperandRange operands = terminator.getSuccessorOperands(successor);
589a9ab845cSSrishti Srivastava     MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
590a9ab845cSSrishti Srivastava     for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
591a9ab845cSSrishti Srivastava       meet(getLatticeElement(opOperand.get()),
5924b3f251bSdonald chen            *getLatticeElementFor(getProgramPointAfter(terminator), input));
593a9ab845cSSrishti Srivastava       unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
594a9ab845cSSrishti Srivastava     }
595a9ab845cSSrishti Srivastava   }
596a9ab845cSSrishti Srivastava   // Visit operands of the branch op not forwarded to the next region.
597a9ab845cSSrishti Srivastava   // (Like e.g. the boolean of `scf.conditional`)
598a9ab845cSSrishti Srivastava   for (int index : unaccounted.set_bits()) {
599a9ab845cSSrishti Srivastava     visitBranchOperand(terminator->getOpOperand(index));
600a9ab845cSSrishti Srivastava   }
601a9ab845cSSrishti Srivastava }
602a9ab845cSSrishti Srivastava 
6034e98d611SMatthias Kramm const AbstractSparseLattice *
6044b3f251bSdonald chen AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
6054b3f251bSdonald chen     ProgramPoint *point, Value value) {
6064e98d611SMatthias Kramm   AbstractSparseLattice *state = getLatticeElement(value);
6074e98d611SMatthias Kramm   addDependency(state, point);
6084e98d611SMatthias Kramm   return state;
6094e98d611SMatthias Kramm }
6104e98d611SMatthias Kramm 
6114e98d611SMatthias Kramm void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
6124e98d611SMatthias Kramm     ArrayRef<AbstractSparseLattice *> lattices) {
6134e98d611SMatthias Kramm   for (AbstractSparseLattice *lattice : lattices)
6144e98d611SMatthias Kramm     setToExitState(lattice);
6154e98d611SMatthias Kramm }
6164e98d611SMatthias Kramm 
6174e98d611SMatthias Kramm void AbstractSparseBackwardDataFlowAnalysis::meet(
6184e98d611SMatthias Kramm     AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
6194e98d611SMatthias Kramm   propagateIfChanged(lhs, lhs->meet(rhs));
6204e98d611SMatthias Kramm }
621