xref: /llvm-project/mlir/include/mlir/Transforms/RegionUtils.h (revision a506279e5c5d5668e66b0749c26a93d8d373931a)
1 //===- RegionUtils.h - Region-related transformation utilities --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TRANSFORMS_REGIONUTILS_H_
10 #define MLIR_TRANSFORMS_REGIONUTILS_H_
11 
12 #include "mlir/IR/Region.h"
13 #include "mlir/IR/Value.h"
14 
15 #include "llvm/ADT/SetVector.h"
16 
17 namespace mlir {
18 class RewriterBase;
19 
20 /// Check if all values in the provided range are defined above the `limit`
21 /// region.  That is, if they are defined in a region that is a proper ancestor
22 /// of `limit`.
23 template <typename Range>
areValuesDefinedAbove(Range values,Region & limit)24 bool areValuesDefinedAbove(Range values, Region &limit) {
25   for (Value v : values)
26     if (!v.getParentRegion()->isProperAncestor(&limit))
27       return false;
28   return true;
29 }
30 
31 /// Replace all uses of `orig` within the given region with `replacement`.
32 void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region);
33 
34 /// Calls `callback` for each use of a value within `region` or its descendants
35 /// that was defined at the ancestors of the `limit`.
36 void visitUsedValuesDefinedAbove(Region &region, Region &limit,
37                                  function_ref<void(OpOperand *)> callback);
38 
39 /// Calls `callback` for each use of a value within any of the regions provided
40 /// that was defined in one of the ancestors.
41 void visitUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
42                                  function_ref<void(OpOperand *)> callback);
43 
44 /// Fill `values` with a list of values defined at the ancestors of the `limit`
45 /// region and used within `region` or its descendants.
46 void getUsedValuesDefinedAbove(Region &region, Region &limit,
47                                SetVector<Value> &values);
48 
49 /// Fill `values` with a list of values used within any of the regions provided
50 /// but defined in one of the ancestors.
51 void getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
52                                SetVector<Value> &values);
53 
54 /// Make a region isolated from above
55 /// - Capture the values that are defined above the region and used within it.
56 /// - Append to the entry block arguments that represent the captured values
57 /// (one per captured value).
58 /// - Replace all uses within the region of the captured values with the
59 ///   newly added arguments.
60 /// - `cloneOperationIntoRegion` is a callback that allows caller to specify
61 ///   if the operation defining an `OpOperand` needs to be cloned into the
62 ///   region. Then the operands of this operation become part of the captured
63 ///   values set (unless the operations that define the operands themeselves
64 ///   are to be cloned). The cloned operations are added to the entry block
65 ///   of the region.
66 /// Return the set of captured values for the operation.
67 SmallVector<Value> makeRegionIsolatedFromAbove(
68     RewriterBase &rewriter, Region &region,
69     llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
70         [](Operation *) { return false; });
71 
72 /// Run a set of structural simplifications over the given regions. This
73 /// includes transformations like unreachable block elimination, dead argument
74 /// elimination, as well as some other DCE. This function returns success if any
75 /// of the regions were simplified, failure otherwise. The provided rewriter is
76 /// used to notify callers of operation and block deletion.
77 /// Structurally similar blocks will be merged if the `mergeBlock` argument is
78 /// true. Note this can lead to merged blocks with extra arguments.
79 LogicalResult simplifyRegions(RewriterBase &rewriter,
80                               MutableArrayRef<Region> regions,
81                               bool mergeBlocks = true);
82 
83 /// Erase the unreachable blocks within the provided regions. Returns success
84 /// if any blocks were erased, failure otherwise.
85 LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
86                                      MutableArrayRef<Region> regions);
87 
88 /// This function returns success if any operations or arguments were deleted,
89 /// failure otherwise.
90 LogicalResult runRegionDCE(RewriterBase &rewriter,
91                            MutableArrayRef<Region> regions);
92 
93 } // namespace mlir
94 
95 #endif // MLIR_TRANSFORMS_REGIONUTILS_H_
96