xref: /llvm-project/mlir/lib/Analysis/SliceWalk.cpp (revision bf68e9047f62c22ca87f9a4a7c59a46b3de06abb)
114153654SChristian Ulmann #include "mlir/Analysis/SliceWalk.h"
214153654SChristian Ulmann #include "mlir/Interfaces/ControlFlowInterfaces.h"
314153654SChristian Ulmann 
414153654SChristian Ulmann using namespace mlir;
514153654SChristian Ulmann 
614153654SChristian Ulmann WalkContinuation mlir::walkSlice(ValueRange rootValues,
714153654SChristian Ulmann                                  WalkCallback walkCallback) {
814153654SChristian Ulmann   // Search the backward slice starting from the root values.
914153654SChristian Ulmann   SmallVector<Value> workList = rootValues;
1014153654SChristian Ulmann   llvm::SmallDenseSet<Value, 16> seenValues;
1114153654SChristian Ulmann   while (!workList.empty()) {
1214153654SChristian Ulmann     // Search the backward slice of the current value.
1314153654SChristian Ulmann     Value current = workList.pop_back_val();
1414153654SChristian Ulmann 
1514153654SChristian Ulmann     // Skip the current value if it has already been seen.
1614153654SChristian Ulmann     if (!seenValues.insert(current).second)
1714153654SChristian Ulmann       continue;
1814153654SChristian Ulmann 
1914153654SChristian Ulmann     // Call the walk callback with the current value.
2014153654SChristian Ulmann     WalkContinuation continuation = walkCallback(current);
2114153654SChristian Ulmann     if (continuation.wasInterrupted())
2214153654SChristian Ulmann       return continuation;
2314153654SChristian Ulmann     if (continuation.wasSkipped())
2414153654SChristian Ulmann       continue;
2514153654SChristian Ulmann 
2614153654SChristian Ulmann     assert(continuation.wasAdvancedTo());
2714153654SChristian Ulmann     // Add the next values to the work list if the walk should continue.
2814153654SChristian Ulmann     workList.append(continuation.getNextValues().begin(),
2914153654SChristian Ulmann                     continuation.getNextValues().end());
3014153654SChristian Ulmann   }
3114153654SChristian Ulmann 
3214153654SChristian Ulmann   return WalkContinuation::skip();
3314153654SChristian Ulmann }
3414153654SChristian Ulmann 
3514153654SChristian Ulmann /// Returns the operands from all predecessor regions that match `operandNumber`
3614153654SChristian Ulmann /// for the `successor` region within `regionOp`.
3714153654SChristian Ulmann static SmallVector<Value>
3814153654SChristian Ulmann getRegionPredecessorOperands(RegionBranchOpInterface regionOp,
3914153654SChristian Ulmann                              RegionSuccessor successor,
4014153654SChristian Ulmann                              unsigned operandNumber) {
4114153654SChristian Ulmann   SmallVector<Value> predecessorOperands;
4214153654SChristian Ulmann 
4314153654SChristian Ulmann   // Returns true if `successors` contains `successor`.
4414153654SChristian Ulmann   auto isContained = [](ArrayRef<RegionSuccessor> successors,
4514153654SChristian Ulmann                         RegionSuccessor successor) {
4614153654SChristian Ulmann     auto *it = llvm::find_if(successors, [&successor](RegionSuccessor curr) {
4714153654SChristian Ulmann       return curr.getSuccessor() == successor.getSuccessor();
4814153654SChristian Ulmann     });
4914153654SChristian Ulmann     return it != successors.end();
5014153654SChristian Ulmann   };
5114153654SChristian Ulmann 
5214153654SChristian Ulmann   // Search the operand ranges on the region operation itself.
5314153654SChristian Ulmann   SmallVector<Attribute> operandAttributes(regionOp->getNumOperands());
5414153654SChristian Ulmann   SmallVector<RegionSuccessor> successors;
5514153654SChristian Ulmann   regionOp.getEntrySuccessorRegions(operandAttributes, successors);
5614153654SChristian Ulmann   if (isContained(successors, successor)) {
5714153654SChristian Ulmann     OperandRange operands = regionOp.getEntrySuccessorOperands(successor);
5814153654SChristian Ulmann     predecessorOperands.push_back(operands[operandNumber]);
5914153654SChristian Ulmann   }
6014153654SChristian Ulmann 
6114153654SChristian Ulmann   // Search the operand ranges on region terminators.
6214153654SChristian Ulmann   for (Region &region : regionOp->getRegions()) {
6314153654SChristian Ulmann     for (Block &block : region) {
6414153654SChristian Ulmann       auto terminatorOp =
6514153654SChristian Ulmann           dyn_cast<RegionBranchTerminatorOpInterface>(block.getTerminator());
6614153654SChristian Ulmann       if (!terminatorOp)
6714153654SChristian Ulmann         continue;
6814153654SChristian Ulmann       SmallVector<Attribute> operandAttributes(terminatorOp->getNumOperands());
6914153654SChristian Ulmann       SmallVector<RegionSuccessor> successors;
7014153654SChristian Ulmann       terminatorOp.getSuccessorRegions(operandAttributes, successors);
7114153654SChristian Ulmann       if (isContained(successors, successor)) {
7214153654SChristian Ulmann         OperandRange operands = terminatorOp.getSuccessorOperands(successor);
7314153654SChristian Ulmann         predecessorOperands.push_back(operands[operandNumber]);
7414153654SChristian Ulmann       }
7514153654SChristian Ulmann     }
7614153654SChristian Ulmann   }
7714153654SChristian Ulmann 
7814153654SChristian Ulmann   return predecessorOperands;
7914153654SChristian Ulmann }
8014153654SChristian Ulmann 
8114153654SChristian Ulmann /// Returns the predecessor branch operands that match `blockArg`, or nullopt if
8214153654SChristian Ulmann /// some of the predecessor terminators do not implement the BranchOpInterface.
8314153654SChristian Ulmann static std::optional<SmallVector<Value>>
8414153654SChristian Ulmann getBlockPredecessorOperands(BlockArgument blockArg) {
8514153654SChristian Ulmann   Block *block = blockArg.getOwner();
8614153654SChristian Ulmann 
8714153654SChristian Ulmann   // Search the predecessor operands for all predecessor terminators.
8814153654SChristian Ulmann   SmallVector<Value> predecessorOperands;
8914153654SChristian Ulmann   for (auto it = block->pred_begin(); it != block->pred_end(); ++it) {
9014153654SChristian Ulmann     Block *predecessor = *it;
9114153654SChristian Ulmann     auto branchOp = dyn_cast<BranchOpInterface>(predecessor->getTerminator());
9214153654SChristian Ulmann     if (!branchOp)
9314153654SChristian Ulmann       return std::nullopt;
9414153654SChristian Ulmann     SuccessorOperands successorOperands =
9514153654SChristian Ulmann         branchOp.getSuccessorOperands(it.getSuccessorIndex());
9614153654SChristian Ulmann     // Store the predecessor operand if the block argument matches an operand
9714153654SChristian Ulmann     // and is not produced by the terminator.
9814153654SChristian Ulmann     if (Value operand = successorOperands[blockArg.getArgNumber()])
9914153654SChristian Ulmann       predecessorOperands.push_back(operand);
10014153654SChristian Ulmann   }
10114153654SChristian Ulmann 
10214153654SChristian Ulmann   return predecessorOperands;
10314153654SChristian Ulmann }
10414153654SChristian Ulmann 
10514153654SChristian Ulmann std::optional<SmallVector<Value>>
10614153654SChristian Ulmann mlir::getControlFlowPredecessors(Value value) {
10714153654SChristian Ulmann   if (OpResult opResult = dyn_cast<OpResult>(value)) {
108*bf68e904SChristian Ulmann     if (auto selectOp = opResult.getDefiningOp<SelectLikeOpInterface>())
109*bf68e904SChristian Ulmann       return SmallVector<Value>(
110*bf68e904SChristian Ulmann           {selectOp.getTrueValue(), selectOp.getFalseValue()});
111*bf68e904SChristian Ulmann     auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
11214153654SChristian Ulmann     // If the interface is not implemented, there are no control flow
11314153654SChristian Ulmann     // predecessors to work with.
11414153654SChristian Ulmann     if (!regionOp)
11514153654SChristian Ulmann       return std::nullopt;
11614153654SChristian Ulmann     // Add the control flow predecessor operands to the work list.
11714153654SChristian Ulmann     RegionSuccessor region(regionOp->getResults());
11814153654SChristian Ulmann     SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
11914153654SChristian Ulmann         regionOp, region, opResult.getResultNumber());
12014153654SChristian Ulmann     return predecessorOperands;
12114153654SChristian Ulmann   }
12214153654SChristian Ulmann 
12314153654SChristian Ulmann   auto blockArg = cast<BlockArgument>(value);
12414153654SChristian Ulmann   Block *block = blockArg.getOwner();
12514153654SChristian Ulmann   // Search the region predecessor operands for structured control flow.
12614153654SChristian Ulmann   if (block->isEntryBlock()) {
12714153654SChristian Ulmann     if (auto regionBranchOp =
12814153654SChristian Ulmann             dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
12914153654SChristian Ulmann       RegionSuccessor region(blockArg.getParentRegion());
13014153654SChristian Ulmann       SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
13114153654SChristian Ulmann           regionBranchOp, region, blockArg.getArgNumber());
13214153654SChristian Ulmann       return predecessorOperands;
13314153654SChristian Ulmann     }
13414153654SChristian Ulmann     // If the interface is not implemented, there are no control flow
13514153654SChristian Ulmann     // predecessors to work with.
13614153654SChristian Ulmann     return std::nullopt;
13714153654SChristian Ulmann   }
13814153654SChristian Ulmann 
13914153654SChristian Ulmann   // Search the block predecessor operands for unstructured control flow.
14014153654SChristian Ulmann   return getBlockPredecessorOperands(blockArg);
14114153654SChristian Ulmann }
142