xref: /llvm-project/mlir/lib/Interfaces/ControlFlowInterfaces.cpp (revision dd450f08cfeb9da372cbe459058bc9ae9425f862)
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // ControlFlowInterfaces
19 //===----------------------------------------------------------------------===//
20 
21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
22 
SuccessorOperands(MutableOperandRange forwardedOperands)23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
24     : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
25 }
26 
SuccessorOperands(unsigned int producedOperandCount,MutableOperandRange forwardedOperands)27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28                                      MutableOperandRange forwardedOperands)
29     : producedOperandCount(producedOperandCount),
30       forwardedOperands(std::move(forwardedOperands)) {}
31 
32 //===----------------------------------------------------------------------===//
33 // BranchOpInterface
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
39 std::optional<BlockArgument>
getBranchSuccessorArgument(const SuccessorOperands & operands,unsigned operandIndex,Block * successor)40 detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
41                                    unsigned operandIndex, Block *successor) {
42   OperandRange forwardedOperands = operands.getForwardedOperands();
43   // Check that the operands are valid.
44   if (forwardedOperands.empty())
45     return std::nullopt;
46 
47   // Check to ensure that this operand is within the range.
48   unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49   if (operandIndex < operandsStart ||
50       operandIndex >= (operandsStart + forwardedOperands.size()))
51     return std::nullopt;
52 
53   // Index the successor.
54   unsigned argIndex =
55       operands.getProducedOperandCount() + operandIndex - operandsStart;
56   return successor->getArgument(argIndex);
57 }
58 
59 /// Verify that the given operands match those of the given successor block.
60 LogicalResult
verifyBranchSuccessorOperands(Operation * op,unsigned succNo,const SuccessorOperands & operands)61 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
62                                       const SuccessorOperands &operands) {
63   // Check the count.
64   unsigned operandCount = operands.size();
65   Block *destBB = op->getSuccessor(succNo);
66   if (operandCount != destBB->getNumArguments())
67     return op->emitError() << "branch has " << operandCount
68                            << " operands for successor #" << succNo
69                            << ", but target block has "
70                            << destBB->getNumArguments();
71 
72   // Check the types.
73   for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74        ++i) {
75     if (!cast<BranchOpInterface>(op).areTypesCompatible(
76             operands[i].getType(), destBB->getArgument(i).getType()))
77       return op->emitError() << "type mismatch for bb argument #" << i
78                              << " of successor #" << succNo;
79   }
80   return success();
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // RegionBranchOpInterface
85 //===----------------------------------------------------------------------===//
86 
printRegionEdgeName(InFlightDiagnostic & diag,RegionBranchPoint sourceNo,RegionBranchPoint succRegionNo)87 static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
88                                                RegionBranchPoint sourceNo,
89                                                RegionBranchPoint succRegionNo) {
90   diag << "from ";
91   if (Region *region = sourceNo.getRegionOrNull())
92     diag << "Region #" << region->getRegionNumber();
93   else
94     diag << "parent operands";
95 
96   diag << " to ";
97   if (Region *region = succRegionNo.getRegionOrNull())
98     diag << "Region #" << region->getRegionNumber();
99   else
100     diag << "parent results";
101   return diag;
102 }
103 
104 /// Verify that types match along all region control flow edges originating from
105 /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
106 /// types of the inputs that flow to a successor region.
107 static LogicalResult
verifyTypesAlongAllEdges(Operation * op,RegionBranchPoint sourcePoint,function_ref<FailureOr<TypeRange> (RegionBranchPoint)> getInputsTypesForRegion)108 verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
109                          function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
110                              getInputsTypesForRegion) {
111   auto regionInterface = cast<RegionBranchOpInterface>(op);
112 
113   SmallVector<RegionSuccessor, 2> successors;
114   regionInterface.getSuccessorRegions(sourcePoint, successors);
115 
116   for (RegionSuccessor &succ : successors) {
117     FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
118     if (failed(sourceTypes))
119       return failure();
120 
121     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
122     if (sourceTypes->size() != succInputsTypes.size()) {
123       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
124       return printRegionEdgeName(diag, sourcePoint, succ)
125              << ": source has " << sourceTypes->size()
126              << " operands, but target successor needs "
127              << succInputsTypes.size();
128     }
129 
130     for (const auto &typesIdx :
131          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
132       Type sourceType = std::get<0>(typesIdx.value());
133       Type inputType = std::get<1>(typesIdx.value());
134       if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
135         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
136         return printRegionEdgeName(diag, sourcePoint, succ)
137                << ": source type #" << typesIdx.index() << " " << sourceType
138                << " should match input type #" << typesIdx.index() << " "
139                << inputType;
140       }
141     }
142   }
143   return success();
144 }
145 
146 /// Verify that types match along control flow edges described the given op.
verifyTypesAlongControlFlowEdges(Operation * op)147 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
148   auto regionInterface = cast<RegionBranchOpInterface>(op);
149 
150   auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
151     return regionInterface.getEntrySuccessorOperands(point).getTypes();
152   };
153 
154   // Verify types along control flow edges originating from the parent.
155   if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
156                                       inputTypesFromParent)))
157     return failure();
158 
159   auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
160     if (lhs.size() != rhs.size())
161       return false;
162     for (auto types : llvm::zip(lhs, rhs)) {
163       if (!regionInterface.areTypesCompatible(std::get<0>(types),
164                                               std::get<1>(types))) {
165         return false;
166       }
167     }
168     return true;
169   };
170 
171   // Verify types along control flow edges originating from each region.
172   for (Region &region : op->getRegions()) {
173 
174     // Since there can be multiple terminators implementing the
175     // `RegionBranchTerminatorOpInterface`, all should have the same operand
176     // types when passing them to the same region.
177 
178     SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
179     for (Block &block : region)
180       if (!block.empty())
181         if (auto terminator =
182                 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
183           regionReturnOps.push_back(terminator);
184 
185     // If there is no return-like terminator, the op itself should verify
186     // type consistency.
187     if (regionReturnOps.empty())
188       continue;
189 
190     auto inputTypesForRegion =
191         [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
192       std::optional<OperandRange> regionReturnOperands;
193       for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
194         auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
195 
196         if (!regionReturnOperands) {
197           regionReturnOperands = terminatorOperands;
198           continue;
199         }
200 
201         // Found more than one ReturnLike terminator. Make sure the operand
202         // types match with the first one.
203         if (!areTypesCompatible(regionReturnOperands->getTypes(),
204                                 terminatorOperands.getTypes())) {
205           InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
206           return printRegionEdgeName(diag, region, point)
207                  << " operands mismatch between return-like terminators";
208         }
209       }
210 
211       // All successors get the same set of operand types.
212       return TypeRange(regionReturnOperands->getTypes());
213     };
214 
215     if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
216       return failure();
217   }
218 
219   return success();
220 }
221 
222 /// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
223 /// this function returns "true" for a successor region. The first parameter is
224 /// the successor region. The second parameter indicates all already visited
225 /// regions.
226 using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
227 
228 /// Traverse the region graph starting at `begin`. The traversal is interrupted
229 /// if `stopCondition` evaluates to "true" for a successor region. In that case,
230 /// this function returns "true". Otherwise, if the traversal was not
231 /// interrupted, this function returns "false".
traverseRegionGraph(Region * begin,StopConditionFn stopConditionFn)232 static bool traverseRegionGraph(Region *begin,
233                                 StopConditionFn stopConditionFn) {
234   auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
235   SmallVector<bool> visited(op->getNumRegions(), false);
236   visited[begin->getRegionNumber()] = true;
237 
238   // Retrieve all successors of the region and enqueue them in the worklist.
239   SmallVector<Region *> worklist;
240   auto enqueueAllSuccessors = [&](Region *region) {
241     SmallVector<RegionSuccessor> successors;
242     op.getSuccessorRegions(region, successors);
243     for (RegionSuccessor successor : successors)
244       if (!successor.isParent())
245         worklist.push_back(successor.getSuccessor());
246   };
247   enqueueAllSuccessors(begin);
248 
249   // Process all regions in the worklist via DFS.
250   while (!worklist.empty()) {
251     Region *nextRegion = worklist.pop_back_val();
252     if (stopConditionFn(nextRegion, visited))
253       return true;
254     if (visited[nextRegion->getRegionNumber()])
255       continue;
256     visited[nextRegion->getRegionNumber()] = true;
257     enqueueAllSuccessors(nextRegion);
258   }
259 
260   return false;
261 }
262 
263 /// Return `true` if region `r` is reachable from region `begin` according to
264 /// the RegionBranchOpInterface (by taking a branch).
isRegionReachable(Region * begin,Region * r)265 static bool isRegionReachable(Region *begin, Region *r) {
266   assert(begin->getParentOp() == r->getParentOp() &&
267          "expected that both regions belong to the same op");
268   return traverseRegionGraph(begin,
269                              [&](Region *nextRegion, ArrayRef<bool> visited) {
270                                // Interrupt traversal if `r` was reached.
271                                return nextRegion == r;
272                              });
273 }
274 
275 /// Return `true` if `a` and `b` are in mutually exclusive regions.
276 ///
277 /// 1. Find the first common of `a` and `b` (ancestor) that implements
278 ///    RegionBranchOpInterface.
279 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
280 ///    contained.
281 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
282 ///    mutually exclusive if they are not reachable from each other as per
283 ///    RegionBranchOpInterface::getSuccessorRegions.
insideMutuallyExclusiveRegions(Operation * a,Operation * b)284 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
285   assert(a && "expected non-empty operation");
286   assert(b && "expected non-empty operation");
287 
288   auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
289   while (branchOp) {
290     // Check if b is inside branchOp. (We already know that a is.)
291     if (!branchOp->isProperAncestor(b)) {
292       // Check next enclosing RegionBranchOpInterface.
293       branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
294       continue;
295     }
296 
297     // b is contained in branchOp. Retrieve the regions in which `a` and `b`
298     // are contained.
299     Region *regionA = nullptr, *regionB = nullptr;
300     for (Region &r : branchOp->getRegions()) {
301       if (r.findAncestorOpInRegion(*a)) {
302         assert(!regionA && "already found a region for a");
303         regionA = &r;
304       }
305       if (r.findAncestorOpInRegion(*b)) {
306         assert(!regionB && "already found a region for b");
307         regionB = &r;
308       }
309     }
310     assert(regionA && regionB && "could not find region of op");
311 
312     // `a` and `b` are in mutually exclusive regions if both regions are
313     // distinct and neither region is reachable from the other region.
314     return regionA != regionB && !isRegionReachable(regionA, regionB) &&
315            !isRegionReachable(regionB, regionA);
316   }
317 
318   // Could not find a common RegionBranchOpInterface among a's and b's
319   // ancestors.
320   return false;
321 }
322 
isRepetitiveRegion(unsigned index)323 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
324   Region *region = &getOperation()->getRegion(index);
325   return isRegionReachable(region, region);
326 }
327 
hasLoop()328 bool RegionBranchOpInterface::hasLoop() {
329   SmallVector<RegionSuccessor> entryRegions;
330   getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
331   for (RegionSuccessor successor : entryRegions)
332     if (!successor.isParent() &&
333         traverseRegionGraph(successor.getSuccessor(),
334                             [](Region *nextRegion, ArrayRef<bool> visited) {
335                               // Interrupt traversal if the region was already
336                               // visited.
337                               return visited[nextRegion->getRegionNumber()];
338                             }))
339       return true;
340   return false;
341 }
342 
getEnclosingRepetitiveRegion(Operation * op)343 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
344   while (Region *region = op->getParentRegion()) {
345     op = region->getParentOp();
346     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
347       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
348         return region;
349   }
350   return nullptr;
351 }
352 
getEnclosingRepetitiveRegion(Value value)353 Region *mlir::getEnclosingRepetitiveRegion(Value value) {
354   Region *region = value.getParentRegion();
355   while (region) {
356     Operation *op = region->getParentOp();
357     if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
358       if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
359         return region;
360     region = op->getParentRegion();
361   }
362   return nullptr;
363 }
364