114153654SChristian Ulmann #include "mlir/Analysis/SliceWalk.h" 214153654SChristian Ulmann #include "mlir/Interfaces/ControlFlowInterfaces.h" 314153654SChristian Ulmann 414153654SChristian Ulmann using namespace mlir; 514153654SChristian Ulmann 614153654SChristian Ulmann WalkContinuation mlir::walkSlice(ValueRange rootValues, 714153654SChristian Ulmann WalkCallback walkCallback) { 814153654SChristian Ulmann // Search the backward slice starting from the root values. 914153654SChristian Ulmann SmallVector<Value> workList = rootValues; 1014153654SChristian Ulmann llvm::SmallDenseSet<Value, 16> seenValues; 1114153654SChristian Ulmann while (!workList.empty()) { 1214153654SChristian Ulmann // Search the backward slice of the current value. 1314153654SChristian Ulmann Value current = workList.pop_back_val(); 1414153654SChristian Ulmann 1514153654SChristian Ulmann // Skip the current value if it has already been seen. 1614153654SChristian Ulmann if (!seenValues.insert(current).second) 1714153654SChristian Ulmann continue; 1814153654SChristian Ulmann 1914153654SChristian Ulmann // Call the walk callback with the current value. 2014153654SChristian Ulmann WalkContinuation continuation = walkCallback(current); 2114153654SChristian Ulmann if (continuation.wasInterrupted()) 2214153654SChristian Ulmann return continuation; 2314153654SChristian Ulmann if (continuation.wasSkipped()) 2414153654SChristian Ulmann continue; 2514153654SChristian Ulmann 2614153654SChristian Ulmann assert(continuation.wasAdvancedTo()); 2714153654SChristian Ulmann // Add the next values to the work list if the walk should continue. 2814153654SChristian Ulmann workList.append(continuation.getNextValues().begin(), 2914153654SChristian Ulmann continuation.getNextValues().end()); 3014153654SChristian Ulmann } 3114153654SChristian Ulmann 3214153654SChristian Ulmann return WalkContinuation::skip(); 3314153654SChristian Ulmann } 3414153654SChristian Ulmann 3514153654SChristian Ulmann /// Returns the operands from all predecessor regions that match `operandNumber` 3614153654SChristian Ulmann /// for the `successor` region within `regionOp`. 3714153654SChristian Ulmann static SmallVector<Value> 3814153654SChristian Ulmann getRegionPredecessorOperands(RegionBranchOpInterface regionOp, 3914153654SChristian Ulmann RegionSuccessor successor, 4014153654SChristian Ulmann unsigned operandNumber) { 4114153654SChristian Ulmann SmallVector<Value> predecessorOperands; 4214153654SChristian Ulmann 4314153654SChristian Ulmann // Returns true if `successors` contains `successor`. 4414153654SChristian Ulmann auto isContained = [](ArrayRef<RegionSuccessor> successors, 4514153654SChristian Ulmann RegionSuccessor successor) { 4614153654SChristian Ulmann auto *it = llvm::find_if(successors, [&successor](RegionSuccessor curr) { 4714153654SChristian Ulmann return curr.getSuccessor() == successor.getSuccessor(); 4814153654SChristian Ulmann }); 4914153654SChristian Ulmann return it != successors.end(); 5014153654SChristian Ulmann }; 5114153654SChristian Ulmann 5214153654SChristian Ulmann // Search the operand ranges on the region operation itself. 5314153654SChristian Ulmann SmallVector<Attribute> operandAttributes(regionOp->getNumOperands()); 5414153654SChristian Ulmann SmallVector<RegionSuccessor> successors; 5514153654SChristian Ulmann regionOp.getEntrySuccessorRegions(operandAttributes, successors); 5614153654SChristian Ulmann if (isContained(successors, successor)) { 5714153654SChristian Ulmann OperandRange operands = regionOp.getEntrySuccessorOperands(successor); 5814153654SChristian Ulmann predecessorOperands.push_back(operands[operandNumber]); 5914153654SChristian Ulmann } 6014153654SChristian Ulmann 6114153654SChristian Ulmann // Search the operand ranges on region terminators. 6214153654SChristian Ulmann for (Region ®ion : regionOp->getRegions()) { 6314153654SChristian Ulmann for (Block &block : region) { 6414153654SChristian Ulmann auto terminatorOp = 6514153654SChristian Ulmann dyn_cast<RegionBranchTerminatorOpInterface>(block.getTerminator()); 6614153654SChristian Ulmann if (!terminatorOp) 6714153654SChristian Ulmann continue; 6814153654SChristian Ulmann SmallVector<Attribute> operandAttributes(terminatorOp->getNumOperands()); 6914153654SChristian Ulmann SmallVector<RegionSuccessor> successors; 7014153654SChristian Ulmann terminatorOp.getSuccessorRegions(operandAttributes, successors); 7114153654SChristian Ulmann if (isContained(successors, successor)) { 7214153654SChristian Ulmann OperandRange operands = terminatorOp.getSuccessorOperands(successor); 7314153654SChristian Ulmann predecessorOperands.push_back(operands[operandNumber]); 7414153654SChristian Ulmann } 7514153654SChristian Ulmann } 7614153654SChristian Ulmann } 7714153654SChristian Ulmann 7814153654SChristian Ulmann return predecessorOperands; 7914153654SChristian Ulmann } 8014153654SChristian Ulmann 8114153654SChristian Ulmann /// Returns the predecessor branch operands that match `blockArg`, or nullopt if 8214153654SChristian Ulmann /// some of the predecessor terminators do not implement the BranchOpInterface. 8314153654SChristian Ulmann static std::optional<SmallVector<Value>> 8414153654SChristian Ulmann getBlockPredecessorOperands(BlockArgument blockArg) { 8514153654SChristian Ulmann Block *block = blockArg.getOwner(); 8614153654SChristian Ulmann 8714153654SChristian Ulmann // Search the predecessor operands for all predecessor terminators. 8814153654SChristian Ulmann SmallVector<Value> predecessorOperands; 8914153654SChristian Ulmann for (auto it = block->pred_begin(); it != block->pred_end(); ++it) { 9014153654SChristian Ulmann Block *predecessor = *it; 9114153654SChristian Ulmann auto branchOp = dyn_cast<BranchOpInterface>(predecessor->getTerminator()); 9214153654SChristian Ulmann if (!branchOp) 9314153654SChristian Ulmann return std::nullopt; 9414153654SChristian Ulmann SuccessorOperands successorOperands = 9514153654SChristian Ulmann branchOp.getSuccessorOperands(it.getSuccessorIndex()); 9614153654SChristian Ulmann // Store the predecessor operand if the block argument matches an operand 9714153654SChristian Ulmann // and is not produced by the terminator. 9814153654SChristian Ulmann if (Value operand = successorOperands[blockArg.getArgNumber()]) 9914153654SChristian Ulmann predecessorOperands.push_back(operand); 10014153654SChristian Ulmann } 10114153654SChristian Ulmann 10214153654SChristian Ulmann return predecessorOperands; 10314153654SChristian Ulmann } 10414153654SChristian Ulmann 10514153654SChristian Ulmann std::optional<SmallVector<Value>> 10614153654SChristian Ulmann mlir::getControlFlowPredecessors(Value value) { 10714153654SChristian Ulmann if (OpResult opResult = dyn_cast<OpResult>(value)) { 108*bf68e904SChristian Ulmann if (auto selectOp = opResult.getDefiningOp<SelectLikeOpInterface>()) 109*bf68e904SChristian Ulmann return SmallVector<Value>( 110*bf68e904SChristian Ulmann {selectOp.getTrueValue(), selectOp.getFalseValue()}); 111*bf68e904SChristian Ulmann auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>(); 11214153654SChristian Ulmann // If the interface is not implemented, there are no control flow 11314153654SChristian Ulmann // predecessors to work with. 11414153654SChristian Ulmann if (!regionOp) 11514153654SChristian Ulmann return std::nullopt; 11614153654SChristian Ulmann // Add the control flow predecessor operands to the work list. 11714153654SChristian Ulmann RegionSuccessor region(regionOp->getResults()); 11814153654SChristian Ulmann SmallVector<Value> predecessorOperands = getRegionPredecessorOperands( 11914153654SChristian Ulmann regionOp, region, opResult.getResultNumber()); 12014153654SChristian Ulmann return predecessorOperands; 12114153654SChristian Ulmann } 12214153654SChristian Ulmann 12314153654SChristian Ulmann auto blockArg = cast<BlockArgument>(value); 12414153654SChristian Ulmann Block *block = blockArg.getOwner(); 12514153654SChristian Ulmann // Search the region predecessor operands for structured control flow. 12614153654SChristian Ulmann if (block->isEntryBlock()) { 12714153654SChristian Ulmann if (auto regionBranchOp = 12814153654SChristian Ulmann dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { 12914153654SChristian Ulmann RegionSuccessor region(blockArg.getParentRegion()); 13014153654SChristian Ulmann SmallVector<Value> predecessorOperands = getRegionPredecessorOperands( 13114153654SChristian Ulmann regionBranchOp, region, blockArg.getArgNumber()); 13214153654SChristian Ulmann return predecessorOperands; 13314153654SChristian Ulmann } 13414153654SChristian Ulmann // If the interface is not implemented, there are no control flow 13514153654SChristian Ulmann // predecessors to work with. 13614153654SChristian Ulmann return std::nullopt; 13714153654SChristian Ulmann } 13814153654SChristian Ulmann 13914153654SChristian Ulmann // Search the block predecessor operands for unstructured control flow. 14014153654SChristian Ulmann return getBlockPredecessorOperands(blockArg); 14114153654SChristian Ulmann } 142