1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 11 #include "mlir/Analysis/DataFlowFramework.h" 12 #include "mlir/Interfaces/CallInterfaces.h" 13 14 using namespace mlir; 15 using namespace mlir::dataflow; 16 17 //===----------------------------------------------------------------------===// 18 // AbstractSparseLattice 19 //===----------------------------------------------------------------------===// 20 21 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { 22 AnalysisState::onUpdate(solver); 23 24 // Push all users of the value to the queue. 25 for (Operation *user : point.get<Value>().getUsers()) 26 for (DataFlowAnalysis *analysis : useDefSubscribers) 27 solver->enqueue({user, analysis}); 28 } 29 30 //===----------------------------------------------------------------------===// 31 // AbstractSparseForwardDataFlowAnalysis 32 //===----------------------------------------------------------------------===// 33 34 AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis( 35 DataFlowSolver &solver) 36 : DataFlowAnalysis(solver) { 37 registerPointKind<CFGEdge>(); 38 } 39 40 LogicalResult 41 AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) { 42 // Mark the entry block arguments as having reached their pessimistic 43 // fixpoints. 44 for (Region ®ion : top->getRegions()) { 45 if (region.empty()) 46 continue; 47 for (Value argument : region.front().getArguments()) 48 setToEntryState(getLatticeElement(argument)); 49 } 50 51 return initializeRecursively(top); 52 } 53 54 LogicalResult 55 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { 56 // Initialize the analysis by visiting every owner of an SSA value (all 57 // operations and blocks). 58 visitOperation(op); 59 for (Region ®ion : op->getRegions()) { 60 for (Block &block : region) { 61 getOrCreate<Executable>(&block)->blockContentSubscribe(this); 62 visitBlock(&block); 63 for (Operation &op : block) 64 if (failed(initializeRecursively(&op))) 65 return failure(); 66 } 67 } 68 69 return success(); 70 } 71 72 LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) { 73 if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point)) 74 visitOperation(op); 75 else if (Block *block = llvm::dyn_cast_if_present<Block *>(point)) 76 visitBlock(block); 77 else 78 return failure(); 79 return success(); 80 } 81 82 void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { 83 // Exit early on operations with no results. 84 if (op->getNumResults() == 0) 85 return; 86 87 // If the containing block is not executable, bail out. 88 if (!getOrCreate<Executable>(op->getBlock())->isLive()) 89 return; 90 91 // Get the result lattices. 92 SmallVector<AbstractSparseLattice *> resultLattices; 93 resultLattices.reserve(op->getNumResults()); 94 for (Value result : op->getResults()) { 95 AbstractSparseLattice *resultLattice = getLatticeElement(result); 96 resultLattices.push_back(resultLattice); 97 } 98 99 // The results of a region branch operation are determined by control-flow. 100 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { 101 return visitRegionSuccessors({branch}, branch, 102 /*successor=*/RegionBranchPoint::parent(), 103 resultLattices); 104 } 105 106 // The results of a call operation are determined by the callgraph. 107 if (auto call = dyn_cast<CallOpInterface>(op)) { 108 const auto *predecessors = getOrCreateFor<PredecessorState>(op, call); 109 // If not all return sites are known, then conservatively assume we can't 110 // reason about the data-flow. 111 if (!predecessors->allPredecessorsKnown()) 112 return setAllToEntryStates(resultLattices); 113 for (Operation *predecessor : predecessors->getKnownPredecessors()) 114 for (auto it : llvm::zip(predecessor->getOperands(), resultLattices)) 115 join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it))); 116 return; 117 } 118 119 // Grab the lattice elements of the operands. 120 SmallVector<const AbstractSparseLattice *> operandLattices; 121 operandLattices.reserve(op->getNumOperands()); 122 for (Value operand : op->getOperands()) { 123 AbstractSparseLattice *operandLattice = getLatticeElement(operand); 124 operandLattice->useDefSubscribe(this); 125 operandLattices.push_back(operandLattice); 126 } 127 128 // Invoke the operation transfer function. 129 visitOperationImpl(op, operandLattices, resultLattices); 130 } 131 132 void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { 133 // Exit early on blocks with no arguments. 134 if (block->getNumArguments() == 0) 135 return; 136 137 // If the block is not executable, bail out. 138 if (!getOrCreate<Executable>(block)->isLive()) 139 return; 140 141 // Get the argument lattices. 142 SmallVector<AbstractSparseLattice *> argLattices; 143 argLattices.reserve(block->getNumArguments()); 144 for (BlockArgument argument : block->getArguments()) { 145 AbstractSparseLattice *argLattice = getLatticeElement(argument); 146 argLattices.push_back(argLattice); 147 } 148 149 // The argument lattices of entry blocks are set by region control-flow or the 150 // callgraph. 151 if (block->isEntryBlock()) { 152 // Check if this block is the entry block of a callable region. 153 auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); 154 if (callable && callable.getCallableRegion() == block->getParent()) { 155 const auto *callsites = getOrCreateFor<PredecessorState>(block, callable); 156 // If not all callsites are known, conservatively mark all lattices as 157 // having reached their pessimistic fixpoints. 158 if (!callsites->allPredecessorsKnown()) 159 return setAllToEntryStates(argLattices); 160 for (Operation *callsite : callsites->getKnownPredecessors()) { 161 auto call = cast<CallOpInterface>(callsite); 162 for (auto it : llvm::zip(call.getArgOperands(), argLattices)) 163 join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); 164 } 165 return; 166 } 167 168 // Check if the lattices can be determined from region control flow. 169 if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { 170 return visitRegionSuccessors(block, branch, block->getParent(), 171 argLattices); 172 } 173 174 // Otherwise, we can't reason about the data-flow. 175 return visitNonControlFlowArgumentsImpl(block->getParentOp(), 176 RegionSuccessor(block->getParent()), 177 argLattices, /*firstIndex=*/0); 178 } 179 180 // Iterate over the predecessors of the non-entry block. 181 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); 182 it != e; ++it) { 183 Block *predecessor = *it; 184 185 // If the edge from the predecessor block to the current block is not live, 186 // bail out. 187 auto *edgeExecutable = 188 getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block)); 189 edgeExecutable->blockContentSubscribe(this); 190 if (!edgeExecutable->isLive()) 191 continue; 192 193 // Check if we can reason about the data-flow from the predecessor. 194 if (auto branch = 195 dyn_cast<BranchOpInterface>(predecessor->getTerminator())) { 196 SuccessorOperands operands = 197 branch.getSuccessorOperands(it.getSuccessorIndex()); 198 for (auto [idx, lattice] : llvm::enumerate(argLattices)) { 199 if (Value operand = operands[idx]) { 200 join(lattice, *getLatticeElementFor(block, operand)); 201 } else { 202 // Conservatively consider internally produced arguments as entry 203 // points. 204 setAllToEntryStates(lattice); 205 } 206 } 207 } else { 208 return setAllToEntryStates(argLattices); 209 } 210 } 211 } 212 213 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( 214 ProgramPoint point, RegionBranchOpInterface branch, 215 RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) { 216 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point); 217 assert(predecessors->allPredecessorsKnown() && 218 "unexpected unresolved region successors"); 219 220 for (Operation *op : predecessors->getKnownPredecessors()) { 221 // Get the incoming successor operands. 222 std::optional<OperandRange> operands; 223 224 // Check if the predecessor is the parent op. 225 if (op == branch) { 226 operands = branch.getEntrySuccessorOperands(successor); 227 // Otherwise, try to deduce the operands from a region return-like op. 228 } else if (auto regionTerminator = 229 dyn_cast<RegionBranchTerminatorOpInterface>(op)) { 230 operands = regionTerminator.getSuccessorOperands(successor); 231 } 232 233 if (!operands) { 234 // We can't reason about the data-flow. 235 return setAllToEntryStates(lattices); 236 } 237 238 ValueRange inputs = predecessors->getSuccessorInputs(op); 239 assert(inputs.size() == operands->size() && 240 "expected the same number of successor inputs as operands"); 241 242 unsigned firstIndex = 0; 243 if (inputs.size() != lattices.size()) { 244 if (llvm::dyn_cast_if_present<Operation *>(point)) { 245 if (!inputs.empty()) 246 firstIndex = cast<OpResult>(inputs.front()).getResultNumber(); 247 visitNonControlFlowArgumentsImpl( 248 branch, 249 RegionSuccessor( 250 branch->getResults().slice(firstIndex, inputs.size())), 251 lattices, firstIndex); 252 } else { 253 if (!inputs.empty()) 254 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber(); 255 Region *region = point.get<Block *>()->getParent(); 256 visitNonControlFlowArgumentsImpl( 257 branch, 258 RegionSuccessor(region, region->getArguments().slice( 259 firstIndex, inputs.size())), 260 lattices, firstIndex); 261 } 262 } 263 264 for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) 265 join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it))); 266 } 267 } 268 269 const AbstractSparseLattice * 270 AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, 271 Value value) { 272 AbstractSparseLattice *state = getLatticeElement(value); 273 addDependency(state, point); 274 return state; 275 } 276 277 void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates( 278 ArrayRef<AbstractSparseLattice *> lattices) { 279 for (AbstractSparseLattice *lattice : lattices) 280 setToEntryState(lattice); 281 } 282 283 void AbstractSparseForwardDataFlowAnalysis::join( 284 AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) { 285 propagateIfChanged(lhs, lhs->join(rhs)); 286 } 287 288 //===----------------------------------------------------------------------===// 289 // AbstractSparseBackwardDataFlowAnalysis 290 //===----------------------------------------------------------------------===// 291 292 AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis( 293 DataFlowSolver &solver, SymbolTableCollection &symbolTable) 294 : DataFlowAnalysis(solver), symbolTable(symbolTable) { 295 registerPointKind<CFGEdge>(); 296 } 297 298 LogicalResult 299 AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) { 300 return initializeRecursively(top); 301 } 302 303 LogicalResult 304 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) { 305 visitOperation(op); 306 for (Region ®ion : op->getRegions()) { 307 for (Block &block : region) { 308 getOrCreate<Executable>(&block)->blockContentSubscribe(this); 309 // Initialize ops in reverse order, so we can do as much initial 310 // propagation as possible without having to go through the 311 // solver queue. 312 for (auto it = block.rbegin(); it != block.rend(); it++) 313 if (failed(initializeRecursively(&*it))) 314 return failure(); 315 } 316 } 317 return success(); 318 } 319 320 LogicalResult 321 AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) { 322 if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point)) 323 visitOperation(op); 324 else if (llvm::dyn_cast_if_present<Block *>(point)) 325 // For backward dataflow, we don't have to do any work for the blocks 326 // themselves. CFG edges between blocks are processed by the BranchOp 327 // logic in `visitOperation`, and entry blocks for functions are tied 328 // to the CallOp arguments by visitOperation. 329 return success(); 330 else 331 return failure(); 332 return success(); 333 } 334 335 SmallVector<AbstractSparseLattice *> 336 AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) { 337 SmallVector<AbstractSparseLattice *> resultLattices; 338 resultLattices.reserve(values.size()); 339 for (Value result : values) { 340 AbstractSparseLattice *resultLattice = getLatticeElement(result); 341 resultLattices.push_back(resultLattice); 342 } 343 return resultLattices; 344 } 345 346 SmallVector<const AbstractSparseLattice *> 347 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor( 348 ProgramPoint point, ValueRange values) { 349 SmallVector<const AbstractSparseLattice *> resultLattices; 350 resultLattices.reserve(values.size()); 351 for (Value result : values) { 352 const AbstractSparseLattice *resultLattice = 353 getLatticeElementFor(point, result); 354 resultLattices.push_back(resultLattice); 355 } 356 return resultLattices; 357 } 358 359 static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) { 360 return MutableArrayRef<OpOperand>(operands.getBase(), operands.size()); 361 } 362 363 void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { 364 // If we're in a dead block, bail out. 365 if (!getOrCreate<Executable>(op->getBlock())->isLive()) 366 return; 367 368 SmallVector<AbstractSparseLattice *> operandLattices = 369 getLatticeElements(op->getOperands()); 370 SmallVector<const AbstractSparseLattice *> resultLattices = 371 getLatticeElementsFor(op, op->getResults()); 372 373 // Block arguments of region branch operations flow back into the operands 374 // of the parent op 375 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { 376 visitRegionSuccessors(branch, operandLattices); 377 return; 378 } 379 380 if (auto branch = dyn_cast<BranchOpInterface>(op)) { 381 // Block arguments of successor blocks flow back into our operands. 382 383 // We remember all operands not forwarded to any block in a BitVector. 384 // We can't just cut out a range here, since the non-forwarded ops might 385 // be non-contiguous (if there's more than one successor). 386 BitVector unaccounted(op->getNumOperands(), true); 387 388 for (auto [index, block] : llvm::enumerate(op->getSuccessors())) { 389 SuccessorOperands successorOperands = branch.getSuccessorOperands(index); 390 OperandRange forwarded = successorOperands.getForwardedOperands(); 391 if (!forwarded.empty()) { 392 MutableArrayRef<OpOperand> operands = op->getOpOperands().slice( 393 forwarded.getBeginOperandIndex(), forwarded.size()); 394 for (OpOperand &operand : operands) { 395 unaccounted.reset(operand.getOperandNumber()); 396 if (std::optional<BlockArgument> blockArg = 397 detail::getBranchSuccessorArgument( 398 successorOperands, operand.getOperandNumber(), block)) { 399 meet(getLatticeElement(operand.get()), 400 *getLatticeElementFor(op, *blockArg)); 401 } 402 } 403 } 404 } 405 // Operands not forwarded to successor blocks are typically parameters 406 // of the branch operation itself (for example the boolean for if/else). 407 for (int index : unaccounted.set_bits()) { 408 OpOperand &operand = op->getOpOperand(index); 409 visitBranchOperand(operand); 410 } 411 return; 412 } 413 414 // For function calls, connect the arguments of the entry blocks to the 415 // operands of the call op that are forwarded to these arguments. 416 if (auto call = dyn_cast<CallOpInterface>(op)) { 417 Operation *callableOp = call.resolveCallable(&symbolTable); 418 if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) { 419 // Not all operands of a call op forward to arguments. Such operands are 420 // stored in `unaccounted`. 421 BitVector unaccounted(op->getNumOperands(), true); 422 423 OperandRange argOperands = call.getArgOperands(); 424 MutableArrayRef<OpOperand> argOpOperands = 425 operandsToOpOperands(argOperands); 426 Region *region = callable.getCallableRegion(); 427 if (region && !region->empty()) { 428 Block &block = region->front(); 429 for (auto [blockArg, argOpOperand] : 430 llvm::zip(block.getArguments(), argOpOperands)) { 431 meet(getLatticeElement(argOpOperand.get()), 432 *getLatticeElementFor(op, blockArg)); 433 unaccounted.reset(argOpOperand.getOperandNumber()); 434 } 435 } 436 // Handle the operands of the call op that aren't forwarded to any 437 // arguments. 438 for (int index : unaccounted.set_bits()) { 439 OpOperand &opOperand = op->getOpOperand(index); 440 visitCallOperand(opOperand); 441 } 442 return; 443 } 444 } 445 446 // When the region of an op implementing `RegionBranchOpInterface` has a 447 // terminator implementing `RegionBranchTerminatorOpInterface` or a 448 // return-like terminator, the region's successors' arguments flow back into 449 // the "successor operands" of this terminator. 450 // 451 // A successor operand with respect to an op implementing 452 // `RegionBranchOpInterface` is an operand that is forwarded to a region 453 // successor's input. There are two types of successor operands: the operands 454 // of this op itself and the operands of the terminators of the regions of 455 // this op. 456 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) { 457 if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { 458 visitRegionSuccessorsFromTerminator(terminator, branch); 459 return; 460 } 461 } 462 463 if (op->hasTrait<OpTrait::ReturnLike>()) { 464 // Going backwards, the operands of the return are derived from the 465 // results of all CallOps calling this CallableOp. 466 if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { 467 const PredecessorState *callsites = 468 getOrCreateFor<PredecessorState>(op, callable); 469 if (callsites->allPredecessorsKnown()) { 470 for (Operation *call : callsites->getKnownPredecessors()) { 471 SmallVector<const AbstractSparseLattice *> callResultLattices = 472 getLatticeElementsFor(op, call->getResults()); 473 for (auto [op, result] : 474 llvm::zip(operandLattices, callResultLattices)) 475 meet(op, *result); 476 } 477 } else { 478 // If we don't know all the callers, we can't know where the 479 // returned values go. Note that, in particular, this will trigger 480 // for the return ops of any public functions. 481 setAllToExitStates(operandLattices); 482 } 483 return; 484 } 485 } 486 487 visitOperationImpl(op, operandLattices, resultLattices); 488 } 489 490 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors( 491 RegionBranchOpInterface branch, 492 ArrayRef<AbstractSparseLattice *> operandLattices) { 493 Operation *op = branch.getOperation(); 494 SmallVector<RegionSuccessor> successors; 495 SmallVector<Attribute> operands(op->getNumOperands(), nullptr); 496 branch.getEntrySuccessorRegions(operands, successors); 497 498 // All operands not forwarded to any successor. This set can be non-contiguous 499 // in the presence of multiple successors. 500 BitVector unaccounted(op->getNumOperands(), true); 501 502 for (RegionSuccessor &successor : successors) { 503 OperandRange operands = branch.getEntrySuccessorOperands(successor); 504 MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands); 505 ValueRange inputs = successor.getSuccessorInputs(); 506 for (auto [operand, input] : llvm::zip(opoperands, inputs)) { 507 meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input)); 508 unaccounted.reset(operand.getOperandNumber()); 509 } 510 } 511 // All operands not forwarded to regions are typically parameters of the 512 // branch operation itself (for example the boolean for if/else). 513 for (int index : unaccounted.set_bits()) { 514 visitBranchOperand(op->getOpOperand(index)); 515 } 516 } 517 518 void AbstractSparseBackwardDataFlowAnalysis:: 519 visitRegionSuccessorsFromTerminator( 520 RegionBranchTerminatorOpInterface terminator, 521 RegionBranchOpInterface branch) { 522 assert(isa<RegionBranchTerminatorOpInterface>(terminator) && 523 "expected a `RegionBranchTerminatorOpInterface` op"); 524 assert(terminator->getParentOp() == branch.getOperation() && 525 "expected `branch` to be the parent op of `terminator`"); 526 527 SmallVector<Attribute> operandAttributes(terminator->getNumOperands(), 528 nullptr); 529 SmallVector<RegionSuccessor> successors; 530 terminator.getSuccessorRegions(operandAttributes, successors); 531 // All operands not forwarded to any successor. This set can be 532 // non-contiguous in the presence of multiple successors. 533 BitVector unaccounted(terminator->getNumOperands(), true); 534 535 for (const RegionSuccessor &successor : successors) { 536 ValueRange inputs = successor.getSuccessorInputs(); 537 OperandRange operands = terminator.getSuccessorOperands(successor); 538 MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands); 539 for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) { 540 meet(getLatticeElement(opOperand.get()), 541 *getLatticeElementFor(terminator, input)); 542 unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber()); 543 } 544 } 545 // Visit operands of the branch op not forwarded to the next region. 546 // (Like e.g. the boolean of `scf.conditional`) 547 for (int index : unaccounted.set_bits()) { 548 visitBranchOperand(terminator->getOpOperand(index)); 549 } 550 } 551 552 const AbstractSparseLattice * 553 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, 554 Value value) { 555 AbstractSparseLattice *state = getLatticeElement(value); 556 addDependency(state, point); 557 return state; 558 } 559 560 void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates( 561 ArrayRef<AbstractSparseLattice *> lattices) { 562 for (AbstractSparseLattice *lattice : lattices) 563 setToExitState(lattice); 564 } 565 566 void AbstractSparseBackwardDataFlowAnalysis::meet( 567 AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) { 568 propagateIfChanged(lhs, lhs->meet(rhs)); 569 } 570