xref: /llvm-project/mlir/lib/Analysis/SliceWalk.cpp (revision bf68e9047f62c22ca87f9a4a7c59a46b3de06abb)
1 #include "mlir/Analysis/SliceWalk.h"
2 #include "mlir/Interfaces/ControlFlowInterfaces.h"
3 
4 using namespace mlir;
5 
6 WalkContinuation mlir::walkSlice(ValueRange rootValues,
7                                  WalkCallback walkCallback) {
8   // Search the backward slice starting from the root values.
9   SmallVector<Value> workList = rootValues;
10   llvm::SmallDenseSet<Value, 16> seenValues;
11   while (!workList.empty()) {
12     // Search the backward slice of the current value.
13     Value current = workList.pop_back_val();
14 
15     // Skip the current value if it has already been seen.
16     if (!seenValues.insert(current).second)
17       continue;
18 
19     // Call the walk callback with the current value.
20     WalkContinuation continuation = walkCallback(current);
21     if (continuation.wasInterrupted())
22       return continuation;
23     if (continuation.wasSkipped())
24       continue;
25 
26     assert(continuation.wasAdvancedTo());
27     // Add the next values to the work list if the walk should continue.
28     workList.append(continuation.getNextValues().begin(),
29                     continuation.getNextValues().end());
30   }
31 
32   return WalkContinuation::skip();
33 }
34 
35 /// Returns the operands from all predecessor regions that match `operandNumber`
36 /// for the `successor` region within `regionOp`.
37 static SmallVector<Value>
38 getRegionPredecessorOperands(RegionBranchOpInterface regionOp,
39                              RegionSuccessor successor,
40                              unsigned operandNumber) {
41   SmallVector<Value> predecessorOperands;
42 
43   // Returns true if `successors` contains `successor`.
44   auto isContained = [](ArrayRef<RegionSuccessor> successors,
45                         RegionSuccessor successor) {
46     auto *it = llvm::find_if(successors, [&successor](RegionSuccessor curr) {
47       return curr.getSuccessor() == successor.getSuccessor();
48     });
49     return it != successors.end();
50   };
51 
52   // Search the operand ranges on the region operation itself.
53   SmallVector<Attribute> operandAttributes(regionOp->getNumOperands());
54   SmallVector<RegionSuccessor> successors;
55   regionOp.getEntrySuccessorRegions(operandAttributes, successors);
56   if (isContained(successors, successor)) {
57     OperandRange operands = regionOp.getEntrySuccessorOperands(successor);
58     predecessorOperands.push_back(operands[operandNumber]);
59   }
60 
61   // Search the operand ranges on region terminators.
62   for (Region &region : regionOp->getRegions()) {
63     for (Block &block : region) {
64       auto terminatorOp =
65           dyn_cast<RegionBranchTerminatorOpInterface>(block.getTerminator());
66       if (!terminatorOp)
67         continue;
68       SmallVector<Attribute> operandAttributes(terminatorOp->getNumOperands());
69       SmallVector<RegionSuccessor> successors;
70       terminatorOp.getSuccessorRegions(operandAttributes, successors);
71       if (isContained(successors, successor)) {
72         OperandRange operands = terminatorOp.getSuccessorOperands(successor);
73         predecessorOperands.push_back(operands[operandNumber]);
74       }
75     }
76   }
77 
78   return predecessorOperands;
79 }
80 
81 /// Returns the predecessor branch operands that match `blockArg`, or nullopt if
82 /// some of the predecessor terminators do not implement the BranchOpInterface.
83 static std::optional<SmallVector<Value>>
84 getBlockPredecessorOperands(BlockArgument blockArg) {
85   Block *block = blockArg.getOwner();
86 
87   // Search the predecessor operands for all predecessor terminators.
88   SmallVector<Value> predecessorOperands;
89   for (auto it = block->pred_begin(); it != block->pred_end(); ++it) {
90     Block *predecessor = *it;
91     auto branchOp = dyn_cast<BranchOpInterface>(predecessor->getTerminator());
92     if (!branchOp)
93       return std::nullopt;
94     SuccessorOperands successorOperands =
95         branchOp.getSuccessorOperands(it.getSuccessorIndex());
96     // Store the predecessor operand if the block argument matches an operand
97     // and is not produced by the terminator.
98     if (Value operand = successorOperands[blockArg.getArgNumber()])
99       predecessorOperands.push_back(operand);
100   }
101 
102   return predecessorOperands;
103 }
104 
105 std::optional<SmallVector<Value>>
106 mlir::getControlFlowPredecessors(Value value) {
107   if (OpResult opResult = dyn_cast<OpResult>(value)) {
108     if (auto selectOp = opResult.getDefiningOp<SelectLikeOpInterface>())
109       return SmallVector<Value>(
110           {selectOp.getTrueValue(), selectOp.getFalseValue()});
111     auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
112     // If the interface is not implemented, there are no control flow
113     // predecessors to work with.
114     if (!regionOp)
115       return std::nullopt;
116     // Add the control flow predecessor operands to the work list.
117     RegionSuccessor region(regionOp->getResults());
118     SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
119         regionOp, region, opResult.getResultNumber());
120     return predecessorOperands;
121   }
122 
123   auto blockArg = cast<BlockArgument>(value);
124   Block *block = blockArg.getOwner();
125   // Search the region predecessor operands for structured control flow.
126   if (block->isEntryBlock()) {
127     if (auto regionBranchOp =
128             dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
129       RegionSuccessor region(blockArg.getParentRegion());
130       SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
131           regionBranchOp, region, blockArg.getArgNumber());
132       return predecessorOperands;
133     }
134     // If the interface is not implemented, there are no control flow
135     // predecessors to work with.
136     return std::nullopt;
137   }
138 
139   // Search the block predecessor operands for unstructured control flow.
140   return getBlockPredecessorOperands(blockArg);
141 }
142