//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include #include using namespace mlir; using namespace mlir::dataflow; //===----------------------------------------------------------------------===// // AbstractSparseLattice //===----------------------------------------------------------------------===// void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { AnalysisState::onUpdate(solver); // Push all users of the value to the queue. for (Operation *user : cast(anchor).getUsers()) for (DataFlowAnalysis *analysis : useDefSubscribers) solver->enqueue({solver->getProgramPointAfter(user), analysis}); } //===----------------------------------------------------------------------===// // AbstractSparseForwardDataFlowAnalysis //===----------------------------------------------------------------------===// AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis( DataFlowSolver &solver) : DataFlowAnalysis(solver) { registerAnchorKind(); } LogicalResult AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) { // Mark the entry block arguments as having reached their pessimistic // fixpoints. for (Region ®ion : top->getRegions()) { if (region.empty()) continue; for (Value argument : region.front().getArguments()) setToEntryState(getLatticeElement(argument)); } return initializeRecursively(top); } LogicalResult AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { // Initialize the analysis by visiting every owner of an SSA value (all // operations and blocks). if (failed(visitOperation(op))) return failure(); for (Region ®ion : op->getRegions()) { for (Block &block : region) { getOrCreate(getProgramPointBefore(&block)) ->blockContentSubscribe(this); visitBlock(&block); for (Operation &op : block) if (failed(initializeRecursively(&op))) return failure(); } } return success(); } LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) { if (!point->isBlockStart()) return visitOperation(point->getPrevOp()); visitBlock(point->getBlock()); return success(); } LogicalResult AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { // Exit early on operations with no results. if (op->getNumResults() == 0) return success(); // If the containing block is not executable, bail out. if (op->getBlock() != nullptr && !getOrCreate(getProgramPointBefore(op->getBlock()))->isLive()) return success(); // Get the result lattices. SmallVector resultLattices; resultLattices.reserve(op->getNumResults()); for (Value result : op->getResults()) { AbstractSparseLattice *resultLattice = getLatticeElement(result); resultLattices.push_back(resultLattice); } // The results of a region branch operation are determined by control-flow. if (auto branch = dyn_cast(op)) { visitRegionSuccessors(getProgramPointAfter(branch), branch, /*successor=*/RegionBranchPoint::parent(), resultLattices); return success(); } // Grab the lattice elements of the operands. SmallVector operandLattices; operandLattices.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { AbstractSparseLattice *operandLattice = getLatticeElement(operand); operandLattice->useDefSubscribe(this); operandLattices.push_back(operandLattice); } if (auto call = dyn_cast(op)) { // If the call operation is to an external function, attempt to infer the // results from the call arguments. auto callable = dyn_cast_if_present(call.resolveCallable()); if (!getSolverConfig().isInterprocedural() || (callable && !callable.getCallableRegion())) { visitExternalCallImpl(call, operandLattices, resultLattices); return success(); } // Otherwise, the results of a call operation are determined by the // callgraph. const auto *predecessors = getOrCreateFor( getProgramPointAfter(op), getProgramPointAfter(call)); // If not all return sites are known, then conservatively assume we can't // reason about the data-flow. if (!predecessors->allPredecessorsKnown()) { setAllToEntryStates(resultLattices); return success(); } for (Operation *predecessor : predecessors->getKnownPredecessors()) for (auto &&[operand, resLattice] : llvm::zip(predecessor->getOperands(), resultLattices)) join(resLattice, *getLatticeElementFor(getProgramPointAfter(op), operand)); return success(); } // Invoke the operation transfer function. return visitOperationImpl(op, operandLattices, resultLattices); } void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { // Exit early on blocks with no arguments. if (block->getNumArguments() == 0) return; // If the block is not executable, bail out. if (!getOrCreate(getProgramPointBefore(block))->isLive()) return; // Get the argument lattices. SmallVector argLattices; argLattices.reserve(block->getNumArguments()); for (BlockArgument argument : block->getArguments()) { AbstractSparseLattice *argLattice = getLatticeElement(argument); argLattices.push_back(argLattice); } // The argument lattices of entry blocks are set by region control-flow or the // callgraph. if (block->isEntryBlock()) { // Check if this block is the entry block of a callable region. auto callable = dyn_cast(block->getParentOp()); if (callable && callable.getCallableRegion() == block->getParent()) { const auto *callsites = getOrCreateFor( getProgramPointBefore(block), getProgramPointAfter(callable)); // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. if (!callsites->allPredecessorsKnown() || !getSolverConfig().isInterprocedural()) { return setAllToEntryStates(argLattices); } for (Operation *callsite : callsites->getKnownPredecessors()) { auto call = cast(callsite); for (auto it : llvm::zip(call.getArgOperands(), argLattices)) join(std::get<1>(it), *getLatticeElementFor(getProgramPointBefore(block), std::get<0>(it))); } return; } // Check if the lattices can be determined from region control flow. if (auto branch = dyn_cast(block->getParentOp())) { return visitRegionSuccessors(getProgramPointBefore(block), branch, block->getParent(), argLattices); } // Otherwise, we can't reason about the data-flow. return visitNonControlFlowArgumentsImpl(block->getParentOp(), RegionSuccessor(block->getParent()), argLattices, /*firstIndex=*/0); } // Iterate over the predecessors of the non-entry block. for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { Block *predecessor = *it; // If the edge from the predecessor block to the current block is not live, // bail out. auto *edgeExecutable = getOrCreate(getLatticeAnchor(predecessor, block)); edgeExecutable->blockContentSubscribe(this); if (!edgeExecutable->isLive()) continue; // Check if we can reason about the data-flow from the predecessor. if (auto branch = dyn_cast(predecessor->getTerminator())) { SuccessorOperands operands = branch.getSuccessorOperands(it.getSuccessorIndex()); for (auto [idx, lattice] : llvm::enumerate(argLattices)) { if (Value operand = operands[idx]) { join(lattice, *getLatticeElementFor(getProgramPointBefore(block), operand)); } else { // Conservatively consider internally produced arguments as entry // points. setAllToEntryStates(lattice); } } } else { return setAllToEntryStates(argLattices); } } } void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor, ArrayRef lattices) { const auto *predecessors = getOrCreateFor(point, point); assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); for (Operation *op : predecessors->getKnownPredecessors()) { // Get the incoming successor operands. std::optional operands; // Check if the predecessor is the parent op. if (op == branch) { operands = branch.getEntrySuccessorOperands(successor); // Otherwise, try to deduce the operands from a region return-like op. } else if (auto regionTerminator = dyn_cast(op)) { operands = regionTerminator.getSuccessorOperands(successor); } if (!operands) { // We can't reason about the data-flow. return setAllToEntryStates(lattices); } ValueRange inputs = predecessors->getSuccessorInputs(op); assert(inputs.size() == operands->size() && "expected the same number of successor inputs as operands"); unsigned firstIndex = 0; if (inputs.size() != lattices.size()) { if (!point->isBlockStart()) { if (!inputs.empty()) firstIndex = cast(inputs.front()).getResultNumber(); visitNonControlFlowArgumentsImpl( branch, RegionSuccessor( branch->getResults().slice(firstIndex, inputs.size())), lattices, firstIndex); } else { if (!inputs.empty()) firstIndex = cast(inputs.front()).getArgNumber(); Region *region = point->getBlock()->getParent(); visitNonControlFlowArgumentsImpl( branch, RegionSuccessor(region, region->getArguments().slice( firstIndex, inputs.size())), lattices, firstIndex); } } for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it))); } } const AbstractSparseLattice * AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint *point, Value value) { AbstractSparseLattice *state = getLatticeElement(value); addDependency(state, point); return state; } void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates( ArrayRef lattices) { for (AbstractSparseLattice *lattice : lattices) setToEntryState(lattice); } void AbstractSparseForwardDataFlowAnalysis::join( AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) { propagateIfChanged(lhs, lhs->join(rhs)); } //===----------------------------------------------------------------------===// // AbstractSparseBackwardDataFlowAnalysis //===----------------------------------------------------------------------===// AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis( DataFlowSolver &solver, SymbolTableCollection &symbolTable) : DataFlowAnalysis(solver), symbolTable(symbolTable) { registerAnchorKind(); } LogicalResult AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) { return initializeRecursively(top); } LogicalResult AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) { if (failed(visitOperation(op))) return failure(); for (Region ®ion : op->getRegions()) { for (Block &block : region) { getOrCreate(getProgramPointBefore(&block)) ->blockContentSubscribe(this); // Initialize ops in reverse order, so we can do as much initial // propagation as possible without having to go through the // solver queue. for (auto it = block.rbegin(); it != block.rend(); it++) if (failed(initializeRecursively(&*it))) return failure(); } } return success(); } LogicalResult AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint *point) { // For backward dataflow, we don't have to do any work for the blocks // themselves. CFG edges between blocks are processed by the BranchOp // logic in `visitOperation`, and entry blocks for functions are tied // to the CallOp arguments by visitOperation. if (point->isBlockStart()) return success(); return visitOperation(point->getPrevOp()); } SmallVector AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) { SmallVector resultLattices; resultLattices.reserve(values.size()); for (Value result : values) { AbstractSparseLattice *resultLattice = getLatticeElement(result); resultLattices.push_back(resultLattice); } return resultLattices; } SmallVector AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor( ProgramPoint *point, ValueRange values) { SmallVector resultLattices; resultLattices.reserve(values.size()); for (Value result : values) { const AbstractSparseLattice *resultLattice = getLatticeElementFor(point, result); resultLattices.push_back(resultLattice); } return resultLattices; } static MutableArrayRef operandsToOpOperands(OperandRange &operands) { return MutableArrayRef(operands.getBase(), operands.size()); } LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // If we're in a dead block, bail out. if (op->getBlock() != nullptr && !getOrCreate(getProgramPointBefore(op->getBlock()))->isLive()) return success(); SmallVector operandLattices = getLatticeElements(op->getOperands()); SmallVector resultLattices = getLatticeElementsFor(getProgramPointAfter(op), op->getResults()); // Block arguments of region branch operations flow back into the operands // of the parent op if (auto branch = dyn_cast(op)) { visitRegionSuccessors(branch, operandLattices); return success(); } if (auto branch = dyn_cast(op)) { // Block arguments of successor blocks flow back into our operands. // We remember all operands not forwarded to any block in a BitVector. // We can't just cut out a range here, since the non-forwarded ops might // be non-contiguous (if there's more than one successor). BitVector unaccounted(op->getNumOperands(), true); for (auto [index, block] : llvm::enumerate(op->getSuccessors())) { SuccessorOperands successorOperands = branch.getSuccessorOperands(index); OperandRange forwarded = successorOperands.getForwardedOperands(); if (!forwarded.empty()) { MutableArrayRef operands = op->getOpOperands().slice( forwarded.getBeginOperandIndex(), forwarded.size()); for (OpOperand &operand : operands) { unaccounted.reset(operand.getOperandNumber()); if (std::optional blockArg = detail::getBranchSuccessorArgument( successorOperands, operand.getOperandNumber(), block)) { meet(getLatticeElement(operand.get()), *getLatticeElementFor(getProgramPointAfter(op), *blockArg)); } } } } // Operands not forwarded to successor blocks are typically parameters // of the branch operation itself (for example the boolean for if/else). for (int index : unaccounted.set_bits()) { OpOperand &operand = op->getOpOperand(index); visitBranchOperand(operand); } return success(); } // For function calls, connect the arguments of the entry blocks to the // operands of the call op that are forwarded to these arguments. if (auto call = dyn_cast(op)) { Operation *callableOp = call.resolveCallableInTable(&symbolTable); if (auto callable = dyn_cast_or_null(callableOp)) { // Not all operands of a call op forward to arguments. Such operands are // stored in `unaccounted`. BitVector unaccounted(op->getNumOperands(), true); // If the call invokes an external function (or a function treated as // external due to config), defer to the corresponding extension hook. // By default, it just does `visitCallOperand` for all operands. OperandRange argOperands = call.getArgOperands(); MutableArrayRef argOpOperands = operandsToOpOperands(argOperands); Region *region = callable.getCallableRegion(); if (!region || region->empty() || !getSolverConfig().isInterprocedural()) { visitExternalCallImpl(call, operandLattices, resultLattices); return success(); } // Otherwise, propagate information from the entry point of the function // back to operands whenever possible. Block &block = region->front(); for (auto [blockArg, argOpOperand] : llvm::zip(block.getArguments(), argOpOperands)) { meet(getLatticeElement(argOpOperand.get()), *getLatticeElementFor(getProgramPointAfter(op), blockArg)); unaccounted.reset(argOpOperand.getOperandNumber()); } // Handle the operands of the call op that aren't forwarded to any // arguments. for (int index : unaccounted.set_bits()) { OpOperand &opOperand = op->getOpOperand(index); visitCallOperand(opOperand); } return success(); } } // When the region of an op implementing `RegionBranchOpInterface` has a // terminator implementing `RegionBranchTerminatorOpInterface` or a // return-like terminator, the region's successors' arguments flow back into // the "successor operands" of this terminator. // // A successor operand with respect to an op implementing // `RegionBranchOpInterface` is an operand that is forwarded to a region // successor's input. There are two types of successor operands: the operands // of this op itself and the operands of the terminators of the regions of // this op. if (auto terminator = dyn_cast(op)) { if (auto branch = dyn_cast(op->getParentOp())) { visitRegionSuccessorsFromTerminator(terminator, branch); return success(); } } if (op->hasTrait()) { // Going backwards, the operands of the return are derived from the // results of all CallOps calling this CallableOp. if (auto callable = dyn_cast(op->getParentOp())) { const PredecessorState *callsites = getOrCreateFor( getProgramPointAfter(op), getProgramPointAfter(callable)); if (callsites->allPredecessorsKnown()) { for (Operation *call : callsites->getKnownPredecessors()) { SmallVector callResultLattices = getLatticeElementsFor(getProgramPointAfter(op), call->getResults()); for (auto [op, result] : llvm::zip(operandLattices, callResultLattices)) meet(op, *result); } } else { // If we don't know all the callers, we can't know where the // returned values go. Note that, in particular, this will trigger // for the return ops of any public functions. setAllToExitStates(operandLattices); } return success(); } } return visitOperationImpl(op, operandLattices, resultLattices); } void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors( RegionBranchOpInterface branch, ArrayRef operandLattices) { Operation *op = branch.getOperation(); SmallVector successors; SmallVector operands(op->getNumOperands(), nullptr); branch.getEntrySuccessorRegions(operands, successors); // All operands not forwarded to any successor. This set can be non-contiguous // in the presence of multiple successors. BitVector unaccounted(op->getNumOperands(), true); for (RegionSuccessor &successor : successors) { OperandRange operands = branch.getEntrySuccessorOperands(successor); MutableArrayRef opoperands = operandsToOpOperands(operands); ValueRange inputs = successor.getSuccessorInputs(); for (auto [operand, input] : llvm::zip(opoperands, inputs)) { meet(getLatticeElement(operand.get()), *getLatticeElementFor(getProgramPointAfter(op), input)); unaccounted.reset(operand.getOperandNumber()); } } // All operands not forwarded to regions are typically parameters of the // branch operation itself (for example the boolean for if/else). for (int index : unaccounted.set_bits()) { visitBranchOperand(op->getOpOperand(index)); } } void AbstractSparseBackwardDataFlowAnalysis:: visitRegionSuccessorsFromTerminator( RegionBranchTerminatorOpInterface terminator, RegionBranchOpInterface branch) { assert(isa(terminator) && "expected a `RegionBranchTerminatorOpInterface` op"); assert(terminator->getParentOp() == branch.getOperation() && "expected `branch` to be the parent op of `terminator`"); SmallVector operandAttributes(terminator->getNumOperands(), nullptr); SmallVector successors; terminator.getSuccessorRegions(operandAttributes, successors); // All operands not forwarded to any successor. This set can be // non-contiguous in the presence of multiple successors. BitVector unaccounted(terminator->getNumOperands(), true); for (const RegionSuccessor &successor : successors) { ValueRange inputs = successor.getSuccessorInputs(); OperandRange operands = terminator.getSuccessorOperands(successor); MutableArrayRef opOperands = operandsToOpOperands(operands); for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) { meet(getLatticeElement(opOperand.get()), *getLatticeElementFor(getProgramPointAfter(terminator), input)); unaccounted.reset(const_cast(opOperand).getOperandNumber()); } } // Visit operands of the branch op not forwarded to the next region. // (Like e.g. the boolean of `scf.conditional`) for (int index : unaccounted.set_bits()) { visitBranchOperand(terminator->getOpOperand(index)); } } const AbstractSparseLattice * AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor( ProgramPoint *point, Value value) { AbstractSparseLattice *state = getLatticeElement(value); addDependency(state, point); return state; } void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates( ArrayRef lattices) { for (AbstractSparseLattice *lattice : lattices) setToExitState(lattice); } void AbstractSparseBackwardDataFlowAnalysis::meet( AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) { propagateIfChanged(lhs, lhs->meet(rhs)); }