1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 11 #include "mlir/Interfaces/CallInterfaces.h" 12 13 using namespace mlir; 14 using namespace mlir::dataflow; 15 16 //===----------------------------------------------------------------------===// 17 // AbstractSparseLattice 18 //===----------------------------------------------------------------------===// 19 20 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { 21 // Push all users of the value to the queue. 22 for (Operation *user : point.get<Value>().getUsers()) 23 for (DataFlowAnalysis *analysis : useDefSubscribers) 24 solver->enqueue({user, analysis}); 25 } 26 27 //===----------------------------------------------------------------------===// 28 // AbstractSparseDataFlowAnalysis 29 //===----------------------------------------------------------------------===// 30 31 AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis( 32 DataFlowSolver &solver) 33 : DataFlowAnalysis(solver) { 34 registerPointKind<CFGEdge>(); 35 } 36 37 LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) { 38 // Mark the entry block arguments as having reached their pessimistic 39 // fixpoints. 40 for (Region ®ion : top->getRegions()) { 41 if (region.empty()) 42 continue; 43 for (Value argument : region.front().getArguments()) 44 setToEntryState(getLatticeElement(argument)); 45 } 46 47 return initializeRecursively(top); 48 } 49 50 LogicalResult 51 AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) { 52 // Initialize the analysis by visiting every owner of an SSA value (all 53 // operations and blocks). 54 visitOperation(op); 55 for (Region ®ion : op->getRegions()) { 56 for (Block &block : region) { 57 getOrCreate<Executable>(&block)->blockContentSubscribe(this); 58 visitBlock(&block); 59 for (Operation &op : block) 60 if (failed(initializeRecursively(&op))) 61 return failure(); 62 } 63 } 64 65 return success(); 66 } 67 68 LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) { 69 if (Operation *op = point.dyn_cast<Operation *>()) 70 visitOperation(op); 71 else if (Block *block = point.dyn_cast<Block *>()) 72 visitBlock(block); 73 else 74 return failure(); 75 return success(); 76 } 77 78 void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) { 79 // Exit early on operations with no results. 80 if (op->getNumResults() == 0) 81 return; 82 83 // If the containing block is not executable, bail out. 84 if (!getOrCreate<Executable>(op->getBlock())->isLive()) 85 return; 86 87 // Get the result lattices. 88 SmallVector<AbstractSparseLattice *> resultLattices; 89 resultLattices.reserve(op->getNumResults()); 90 for (Value result : op->getResults()) { 91 AbstractSparseLattice *resultLattice = getLatticeElement(result); 92 resultLattices.push_back(resultLattice); 93 } 94 95 // The results of a region branch operation are determined by control-flow. 96 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { 97 return visitRegionSuccessors({branch}, branch, 98 /*successorIndex=*/std::nullopt, 99 resultLattices); 100 } 101 102 // The results of a call operation are determined by the callgraph. 103 if (auto call = dyn_cast<CallOpInterface>(op)) { 104 const auto *predecessors = getOrCreateFor<PredecessorState>(op, call); 105 // If not all return sites are known, then conservatively assume we can't 106 // reason about the data-flow. 107 if (!predecessors->allPredecessorsKnown()) 108 return setAllToEntryStates(resultLattices); 109 for (Operation *predecessor : predecessors->getKnownPredecessors()) 110 for (auto it : llvm::zip(predecessor->getOperands(), resultLattices)) 111 join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it))); 112 return; 113 } 114 115 // Grab the lattice elements of the operands. 116 SmallVector<const AbstractSparseLattice *> operandLattices; 117 operandLattices.reserve(op->getNumOperands()); 118 for (Value operand : op->getOperands()) { 119 AbstractSparseLattice *operandLattice = getLatticeElement(operand); 120 operandLattice->useDefSubscribe(this); 121 operandLattices.push_back(operandLattice); 122 } 123 124 // Invoke the operation transfer function. 125 visitOperationImpl(op, operandLattices, resultLattices); 126 } 127 128 void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) { 129 // Exit early on blocks with no arguments. 130 if (block->getNumArguments() == 0) 131 return; 132 133 // If the block is not executable, bail out. 134 if (!getOrCreate<Executable>(block)->isLive()) 135 return; 136 137 // Get the argument lattices. 138 SmallVector<AbstractSparseLattice *> argLattices; 139 argLattices.reserve(block->getNumArguments()); 140 for (BlockArgument argument : block->getArguments()) { 141 AbstractSparseLattice *argLattice = getLatticeElement(argument); 142 argLattices.push_back(argLattice); 143 } 144 145 // The argument lattices of entry blocks are set by region control-flow or the 146 // callgraph. 147 if (block->isEntryBlock()) { 148 // Check if this block is the entry block of a callable region. 149 auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); 150 if (callable && callable.getCallableRegion() == block->getParent()) { 151 const auto *callsites = getOrCreateFor<PredecessorState>(block, callable); 152 // If not all callsites are known, conservatively mark all lattices as 153 // having reached their pessimistic fixpoints. 154 if (!callsites->allPredecessorsKnown()) 155 return setAllToEntryStates(argLattices); 156 for (Operation *callsite : callsites->getKnownPredecessors()) { 157 auto call = cast<CallOpInterface>(callsite); 158 for (auto it : llvm::zip(call.getArgOperands(), argLattices)) 159 join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); 160 } 161 return; 162 } 163 164 // Check if the lattices can be determined from region control flow. 165 if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { 166 return visitRegionSuccessors( 167 block, branch, block->getParent()->getRegionNumber(), argLattices); 168 } 169 170 // Otherwise, we can't reason about the data-flow. 171 return visitNonControlFlowArgumentsImpl(block->getParentOp(), 172 RegionSuccessor(block->getParent()), 173 argLattices, /*firstIndex=*/0); 174 } 175 176 // Iterate over the predecessors of the non-entry block. 177 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); 178 it != e; ++it) { 179 Block *predecessor = *it; 180 181 // If the edge from the predecessor block to the current block is not live, 182 // bail out. 183 auto *edgeExecutable = 184 getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block)); 185 edgeExecutable->blockContentSubscribe(this); 186 if (!edgeExecutable->isLive()) 187 continue; 188 189 // Check if we can reason about the data-flow from the predecessor. 190 if (auto branch = 191 dyn_cast<BranchOpInterface>(predecessor->getTerminator())) { 192 SuccessorOperands operands = 193 branch.getSuccessorOperands(it.getSuccessorIndex()); 194 for (auto &it : llvm::enumerate(argLattices)) { 195 if (Value operand = operands[it.index()]) { 196 join(it.value(), *getLatticeElementFor(block, operand)); 197 } else { 198 // Conservatively consider internally produced arguments as entry 199 // points. 200 setAllToEntryStates(it.value()); 201 } 202 } 203 } else { 204 return setAllToEntryStates(argLattices); 205 } 206 } 207 } 208 209 void AbstractSparseDataFlowAnalysis::visitRegionSuccessors( 210 ProgramPoint point, RegionBranchOpInterface branch, 211 std::optional<unsigned> successorIndex, 212 ArrayRef<AbstractSparseLattice *> lattices) { 213 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point); 214 assert(predecessors->allPredecessorsKnown() && 215 "unexpected unresolved region successors"); 216 217 for (Operation *op : predecessors->getKnownPredecessors()) { 218 // Get the incoming successor operands. 219 std::optional<OperandRange> operands; 220 221 // Check if the predecessor is the parent op. 222 if (op == branch) { 223 operands = branch.getSuccessorEntryOperands(successorIndex); 224 // Otherwise, try to deduce the operands from a region return-like op. 225 } else { 226 if (isRegionReturnLike(op)) 227 operands = getRegionBranchSuccessorOperands(op, successorIndex); 228 } 229 230 if (!operands) { 231 // We can't reason about the data-flow. 232 return setAllToEntryStates(lattices); 233 } 234 235 ValueRange inputs = predecessors->getSuccessorInputs(op); 236 assert(inputs.size() == operands->size() && 237 "expected the same number of successor inputs as operands"); 238 239 unsigned firstIndex = 0; 240 if (inputs.size() != lattices.size()) { 241 if (point.dyn_cast<Operation *>()) { 242 if (!inputs.empty()) 243 firstIndex = inputs.front().cast<OpResult>().getResultNumber(); 244 visitNonControlFlowArgumentsImpl( 245 branch, 246 RegionSuccessor( 247 branch->getResults().slice(firstIndex, inputs.size())), 248 lattices, firstIndex); 249 } else { 250 if (!inputs.empty()) 251 firstIndex = inputs.front().cast<BlockArgument>().getArgNumber(); 252 Region *region = point.get<Block *>()->getParent(); 253 visitNonControlFlowArgumentsImpl( 254 branch, 255 RegionSuccessor(region, region->getArguments().slice( 256 firstIndex, inputs.size())), 257 lattices, firstIndex); 258 } 259 } 260 261 for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) 262 join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it))); 263 } 264 } 265 266 const AbstractSparseLattice * 267 AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, 268 Value value) { 269 AbstractSparseLattice *state = getLatticeElement(value); 270 addDependency(state, point); 271 return state; 272 } 273 274 void AbstractSparseDataFlowAnalysis::setAllToEntryStates( 275 ArrayRef<AbstractSparseLattice *> lattices) { 276 for (AbstractSparseLattice *lattice : lattices) 277 setToEntryState(lattice); 278 } 279 280 void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs, 281 const AbstractSparseLattice &rhs) { 282 propagateIfChanged(lhs, lhs->join(rhs)); 283 } 284 285 //===----------------------------------------------------------------------===// 286 // AbstractSparseBackwardDataFlowAnalysis 287 //===----------------------------------------------------------------------===// 288 289 AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis( 290 DataFlowSolver &solver, SymbolTableCollection &symbolTable) 291 : DataFlowAnalysis(solver), symbolTable(symbolTable) { 292 registerPointKind<CFGEdge>(); 293 } 294 295 LogicalResult 296 AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) { 297 return initializeRecursively(top); 298 } 299 300 LogicalResult 301 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) { 302 visitOperation(op); 303 for (Region ®ion : op->getRegions()) { 304 for (Block &block : region) { 305 getOrCreate<Executable>(&block)->blockContentSubscribe(this); 306 // Initialize ops in reverse order, so we can do as much initial 307 // propagation as possible without having to go through the 308 // solver queue. 309 for (auto it = block.rbegin(); it != block.rend(); it++) 310 if (failed(initializeRecursively(&*it))) 311 return failure(); 312 } 313 } 314 return success(); 315 } 316 317 LogicalResult 318 AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) { 319 if (Operation *op = point.dyn_cast<Operation *>()) 320 visitOperation(op); 321 else if (point.dyn_cast<Block *>()) 322 // For backward dataflow, we don't have to do any work for the blocks 323 // themselves. CFG edges between blocks are processed by the BranchOp 324 // logic in `visitOperation`, and entry blocks for functions are tied 325 // to the CallOp arguments by visitOperation. 326 return success(); 327 else 328 return failure(); 329 return success(); 330 } 331 332 SmallVector<AbstractSparseLattice *> 333 AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) { 334 SmallVector<AbstractSparseLattice *> resultLattices; 335 resultLattices.reserve(values.size()); 336 for (Value result : values) { 337 AbstractSparseLattice *resultLattice = getLatticeElement(result); 338 resultLattices.push_back(resultLattice); 339 } 340 return resultLattices; 341 } 342 343 SmallVector<const AbstractSparseLattice *> 344 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor( 345 ProgramPoint point, ValueRange values) { 346 SmallVector<const AbstractSparseLattice *> resultLattices; 347 resultLattices.reserve(values.size()); 348 for (Value result : values) { 349 const AbstractSparseLattice *resultLattice = 350 getLatticeElementFor(point, result); 351 resultLattices.push_back(resultLattice); 352 } 353 return resultLattices; 354 } 355 356 static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) { 357 return MutableArrayRef<OpOperand>(operands.getBase(), operands.size()); 358 } 359 360 void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { 361 // If we're in a dead block, bail out. 362 if (!getOrCreate<Executable>(op->getBlock())->isLive()) 363 return; 364 365 SmallVector<AbstractSparseLattice *> operandLattices = 366 getLatticeElements(op->getOperands()); 367 SmallVector<const AbstractSparseLattice *> resultLattices = 368 getLatticeElementsFor(op, op->getResults()); 369 370 // Block arguments of region branch operations flow back into the operands 371 // of the parent op 372 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { 373 visitRegionSuccessors(branch, operandLattices); 374 return; 375 } 376 377 if (auto branch = dyn_cast<BranchOpInterface>(op)) { 378 // Block arguments of successor blocks flow back into our operands. 379 380 // We remember all operands not forwarded to any block in a BitVector. 381 // We can't just cut out a range here, since the non-forwarded ops might 382 // be non-contiguous (if there's more than one successor). 383 BitVector unaccounted(op->getNumOperands(), true); 384 385 for (auto [index, block] : llvm::enumerate(op->getSuccessors())) { 386 SuccessorOperands successorOperands = branch.getSuccessorOperands(index); 387 OperandRange forwarded = successorOperands.getForwardedOperands(); 388 if (!forwarded.empty()) { 389 MutableArrayRef<OpOperand> operands = op->getOpOperands().slice( 390 forwarded.getBeginOperandIndex(), forwarded.size()); 391 for (OpOperand &operand : operands) { 392 unaccounted.reset(operand.getOperandNumber()); 393 if (std::optional<BlockArgument> blockArg = 394 detail::getBranchSuccessorArgument( 395 successorOperands, operand.getOperandNumber(), block)) { 396 meet(getLatticeElement(operand.get()), 397 *getLatticeElementFor(op, *blockArg)); 398 } 399 } 400 } 401 } 402 // Operands not forwarded to successor blocks are typically parameters 403 // of the branch operation itself (for example the boolean for if/else). 404 for (int index : unaccounted.set_bits()) { 405 OpOperand &operand = op->getOpOperand(index); 406 visitBranchOperand(operand); 407 } 408 return; 409 } 410 411 // For function calls, connect the arguments of the entry blocks 412 // to the operands of the call op. 413 if (auto call = dyn_cast<CallOpInterface>(op)) { 414 Operation *callableOp = call.resolveCallable(&symbolTable); 415 if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) { 416 Region *region = callable.getCallableRegion(); 417 if (!region->empty()) { 418 Block &block = region->front(); 419 for (auto [blockArg, operand] : 420 llvm::zip(block.getArguments(), operandLattices)) { 421 meet(operand, *getLatticeElementFor(op, blockArg)); 422 } 423 } 424 return; 425 } 426 } 427 428 // The block arguments of the branched to region flow back into the 429 // operands of the yield operation. 430 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) { 431 if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { 432 SmallVector<RegionSuccessor> successors; 433 SmallVector<Attribute> operands(op->getNumOperands(), nullptr); 434 branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(), 435 operands, successors); 436 // All operands not forwarded to any successor. This set can be 437 // non-contiguous in the presence of multiple successors. 438 BitVector unaccounted(op->getNumOperands(), true); 439 440 for (const RegionSuccessor &successor : successors) { 441 ValueRange inputs = successor.getSuccessorInputs(); 442 Region *region = successor.getSuccessor(); 443 OperandRange operands = 444 region ? terminator.getSuccessorOperands(region->getRegionNumber()) 445 : terminator.getSuccessorOperands({}); 446 MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands); 447 for (auto [opoperand, input] : llvm::zip(opoperands, inputs)) { 448 meet(getLatticeElement(opoperand.get()), 449 *getLatticeElementFor(op, input)); 450 unaccounted.reset( 451 const_cast<OpOperand &>(opoperand).getOperandNumber()); 452 } 453 } 454 // Visit operands of the branch op not forwarded to the next region. 455 // (Like e.g. the boolean of `scf.conditional`) 456 for (int index : unaccounted.set_bits()) { 457 visitBranchOperand(op->getOpOperand(index)); 458 } 459 return; 460 } 461 } 462 463 // yield-like ops usually don't implement `RegionBranchTerminatorOpInterface`, 464 // since they behave like a return in the sense that they forward to the 465 // results of some other (here: the parent) op. 466 if (op->hasTrait<OpTrait::ReturnLike>()) { 467 if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { 468 OperandRange operands = op->getOperands(); 469 ResultRange results = op->getParentOp()->getResults(); 470 assert(results.size() == operands.size() && 471 "Can't derive arg mapping for yield-like op."); 472 for (auto [operand, result] : llvm::zip(operands, results)) 473 meet(getLatticeElement(operand), *getLatticeElementFor(op, result)); 474 return; 475 } 476 477 // Going backwards, the operands of the return are derived from the 478 // results of all CallOps calling this CallableOp. 479 if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { 480 const PredecessorState *callsites = 481 getOrCreateFor<PredecessorState>(op, callable); 482 if (callsites->allPredecessorsKnown()) { 483 for (Operation *call : callsites->getKnownPredecessors()) { 484 SmallVector<const AbstractSparseLattice *> callResultLattices = 485 getLatticeElementsFor(op, call->getResults()); 486 for (auto [op, result] : 487 llvm::zip(operandLattices, callResultLattices)) 488 meet(op, *result); 489 } 490 } else { 491 // If we don't know all the callers, we can't know where the 492 // returned values go. Note that, in particular, this will trigger 493 // for the return ops of any public functions. 494 setAllToExitStates(operandLattices); 495 } 496 return; 497 } 498 } 499 500 visitOperationImpl(op, operandLattices, resultLattices); 501 } 502 503 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors( 504 RegionBranchOpInterface branch, 505 ArrayRef<AbstractSparseLattice *> operandLattices) { 506 Operation *op = branch.getOperation(); 507 SmallVector<RegionSuccessor> successors; 508 SmallVector<Attribute> operands(op->getNumOperands(), nullptr); 509 branch.getSuccessorRegions(/*index=*/{}, operands, successors); 510 511 // All operands not forwarded to any successor. This set can be non-contiguous 512 // in the presence of multiple successors. 513 BitVector unaccounted(op->getNumOperands(), true); 514 515 for (RegionSuccessor &successor : successors) { 516 Region *region = successor.getSuccessor(); 517 OperandRange operands = 518 region ? branch.getSuccessorEntryOperands(region->getRegionNumber()) 519 : branch.getSuccessorEntryOperands({}); 520 MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands); 521 ValueRange inputs = successor.getSuccessorInputs(); 522 for (auto [operand, input] : llvm::zip(opoperands, inputs)) { 523 meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input)); 524 unaccounted.reset(operand.getOperandNumber()); 525 } 526 } 527 // All operands not forwarded to regions are typically parameters of the 528 // branch operation itself (for example the boolean for if/else). 529 for (int index : unaccounted.set_bits()) { 530 visitBranchOperand(op->getOpOperand(index)); 531 } 532 } 533 534 const AbstractSparseLattice * 535 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, 536 Value value) { 537 AbstractSparseLattice *state = getLatticeElement(value); 538 addDependency(state, point); 539 return state; 540 } 541 542 void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates( 543 ArrayRef<AbstractSparseLattice *> lattices) { 544 for (AbstractSparseLattice *lattice : lattices) 545 setToExitState(lattice); 546 } 547 548 void AbstractSparseBackwardDataFlowAnalysis::meet( 549 AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) { 550 propagateIfChanged(lhs, lhs->meet(rhs)); 551 } 552