xref: /llvm-project/mlir/include/mlir/IR/Region.h (revision 1297c1125f9c284e0cc0f2bf50d4b7ba519f7309)
1 //===- Region.h - MLIR Region Class -----------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the Region and RegionRange classes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_REGION_H
14 #define MLIR_IR_REGION_H
15 
16 #include "mlir/IR/Block.h"
17 
18 namespace mlir {
// Forward declarations. These types are only named in declarations in this
// header (by reference, or as a parameter/return type of a function that is
// defined elsewhere), so their full definitions are not needed here.
class IRMapping;
class TypeRange;
template <typename ValueRangeT>
class ValueTypeRange;
23 
24 /// This class contains a list of basic blocks and a link to the parent
25 /// operation it is attached to.
26 class Region {
27 public:
28   Region() = default;
29   explicit Region(Operation *container);
30   ~Region();
31 
32   /// Return the context this region is inserted in.  The region must have a
33   /// valid parent container.
34   MLIRContext *getContext();
35 
36   /// Return a location for this region. This is the location attached to the
37   /// parent container. The region must have a valid parent container.
38   Location getLoc();
39 
40   //===--------------------------------------------------------------------===//
41   // Block list management
42   //===--------------------------------------------------------------------===//
43 
44   using BlockListType = llvm::iplist<Block>;
45   BlockListType &getBlocks() { return blocks; }
46   Block &emplaceBlock() {
47     push_back(new Block);
48     return back();
49   }
50 
51   // Iteration over the blocks in the region.
52   using iterator = BlockListType::iterator;
53   using reverse_iterator = BlockListType::reverse_iterator;
54 
55   iterator begin() { return blocks.begin(); }
56   iterator end() { return blocks.end(); }
57   reverse_iterator rbegin() { return blocks.rbegin(); }
58   reverse_iterator rend() { return blocks.rend(); }
59 
60   bool empty() { return blocks.empty(); }
61   void push_back(Block *block) { blocks.push_back(block); }
62   void push_front(Block *block) { blocks.push_front(block); }
63 
64   Block &back() { return blocks.back(); }
65   Block &front() { return blocks.front(); }
66 
67   /// Return true if this region has exactly one block.
68   bool hasOneBlock() { return !empty() && std::next(begin()) == end(); }
69 
70   /// getSublistAccess() - Returns pointer to member of region.
71   static BlockListType Region::*getSublistAccess(Block *) {
72     return &Region::blocks;
73   }
74 
75   //===--------------------------------------------------------------------===//
76   // Argument Handling
77   //===--------------------------------------------------------------------===//
78 
79   // This is the list of arguments to the block.
80   using BlockArgListType = MutableArrayRef<BlockArgument>;
81   BlockArgListType getArguments() {
82     return empty() ? BlockArgListType() : front().getArguments();
83   }
84 
85   /// Returns the argument types of the first block within the region.
86   ValueTypeRange<BlockArgListType> getArgumentTypes();
87 
88   using args_iterator = BlockArgListType::iterator;
89   using reverse_args_iterator = BlockArgListType::reverse_iterator;
90   args_iterator args_begin() { return getArguments().begin(); }
91   args_iterator args_end() { return getArguments().end(); }
92   reverse_args_iterator args_rbegin() { return getArguments().rbegin(); }
93   reverse_args_iterator args_rend() { return getArguments().rend(); }
94 
95   bool args_empty() { return getArguments().empty(); }
96 
97   /// Add one value to the argument list.
98   BlockArgument addArgument(Type type, Location loc) {
99     return front().addArgument(type, loc);
100   }
101 
102   /// Insert one value to the position in the argument list indicated by the
103   /// given iterator. The existing arguments are shifted. The block is expected
104   /// not to have predecessors.
105   BlockArgument insertArgument(args_iterator it, Type type, Location loc) {
106     return front().insertArgument(it, type, loc);
107   }
108 
109   /// Add one argument to the argument list for each type specified in the list.
110   /// `locs` contains the locations for each of the new arguments, and must be
111   /// of equal size to `types`.
112   iterator_range<args_iterator> addArguments(TypeRange types,
113                                              ArrayRef<Location> locs);
114 
115   /// Add one value to the argument list at the specified position.
116   BlockArgument insertArgument(unsigned index, Type type, Location loc) {
117     return front().insertArgument(index, type, loc);
118   }
119 
120   /// Erase the argument at 'index' and remove it from the argument list.
121   void eraseArgument(unsigned index) { front().eraseArgument(index); }
122 
123   unsigned getNumArguments() { return getArguments().size(); }
124   BlockArgument getArgument(unsigned i) { return getArguments()[i]; }
125 
126   //===--------------------------------------------------------------------===//
127   // Operation list utilities
128   //===--------------------------------------------------------------------===//
129 
130   /// This class provides iteration over the held operations of blocks directly
131   /// within a region.
132   class OpIterator final
133       : public llvm::iterator_facade_base<OpIterator, std::forward_iterator_tag,
134                                           Operation> {
135   public:
136     /// Initialize OpIterator for a region, specify `end` to return the iterator
137     /// to last operation.
138     explicit OpIterator(Region *region, bool end = false);
139 
140     using llvm::iterator_facade_base<OpIterator, std::forward_iterator_tag,
141                                      Operation>::operator++;
142     OpIterator &operator++();
143     Operation *operator->() const { return &*operation; }
144     Operation &operator*() const { return *operation; }
145 
146     /// Compare this iterator with another.
147     bool operator==(const OpIterator &rhs) const {
148       return operation == rhs.operation;
149     }
150     bool operator!=(const OpIterator &rhs) const { return !(*this == rhs); }
151 
152   private:
153     void skipOverBlocksWithNoOps();
154 
155     /// The region whose operations are being iterated over.
156     Region *region;
157     /// The block of 'region' whose operations are being iterated over.
158     Region::iterator block;
159     /// The current operation within 'block'.
160     Block::iterator operation;
161   };
162 
163   /// This class provides iteration over the held operations of a region for a
164   /// specific operation type.
165   template <typename OpT>
166   using op_iterator = detail::op_iterator<OpT, OpIterator>;
167 
168   /// Return iterators that walk the operations nested directly within this
169   /// region.
170   OpIterator op_begin() { return OpIterator(this); }
171   OpIterator op_end() { return OpIterator(this, /*end=*/true); }
172   iterator_range<OpIterator> getOps() { return {op_begin(), op_end()}; }
173 
174   /// Return iterators that walk operations of type 'T' nested directly within
175   /// this region.
176   template <typename OpT>
177   op_iterator<OpT> op_begin() {
178     return detail::op_filter_iterator<OpT, OpIterator>(op_begin(), op_end());
179   }
180   template <typename OpT>
181   op_iterator<OpT> op_end() {
182     return detail::op_filter_iterator<OpT, OpIterator>(op_end(), op_end());
183   }
184   template <typename OpT>
185   iterator_range<op_iterator<OpT>> getOps() {
186     auto endIt = op_end();
187     return {detail::op_filter_iterator<OpT, OpIterator>(op_begin(), endIt),
188             detail::op_filter_iterator<OpT, OpIterator>(endIt, endIt)};
189   }
190 
191   //===--------------------------------------------------------------------===//
192   // Misc. utilities
193   //===--------------------------------------------------------------------===//
194 
195   /// Return the region containing this region or nullptr if the region is
196   /// attached to a top-level operation.
197   Region *getParentRegion();
198 
199   /// Return the parent operation this region is attached to.
200   Operation *getParentOp() { return container; }
201 
202   /// Find the first parent operation of the given type, or nullptr if there is
203   /// no ancestor operation.
204   template <typename ParentT>
205   ParentT getParentOfType() {
206     auto *region = this;
207     do {
208       if (auto parent = dyn_cast_or_null<ParentT>(region->container))
209         return parent;
210     } while ((region = region->getParentRegion()));
211     return ParentT();
212   }
213 
214   /// Return the number of this region in the parent operation.
215   unsigned getRegionNumber();
216 
217   /// Return true if this region is a proper ancestor of the `other` region.
218   bool isProperAncestor(Region *other);
219 
220   /// Return true if this region is ancestor of the `other` region.  A region
221   /// is considered as its own ancestor, use `isProperAncestor` to avoid this.
222   bool isAncestor(Region *other) {
223     return this == other || isProperAncestor(other);
224   }
225 
226   /// Clone the internal blocks from this region into dest. Any
227   /// cloned blocks are appended to the back of dest. If the mapper
228   /// contains entries for block arguments, these arguments are not included
229   /// in the respective cloned block.
230   ///
231   /// Calling this method from multiple threads is generally safe if through the
232   /// process of cloning, no new uses of 'Value's from outside the region are
233   /// created. Using the mapper, it is possible to avoid adding uses to outside
234   /// operands by remapping them to 'Value's owned by the caller thread.
235   void cloneInto(Region *dest, IRMapping &mapper);
236   /// Clone this region into 'dest' before the given position in 'dest'.
237   void cloneInto(Region *dest, Region::iterator destPos, IRMapping &mapper);
238 
239   /// Takes body of another region (that region will have no body after this
240   /// operation completes).  The current body of this region is cleared.
241   void takeBody(Region &other) {
242     dropAllReferences();
243     blocks.clear();
244     blocks.splice(blocks.end(), other.getBlocks());
245   }
246 
247   /// Returns 'block' if 'block' lies in this region, or otherwise finds the
248   /// ancestor of 'block' that lies in this region. Returns nullptr if the
249   /// latter fails.
250   Block *findAncestorBlockInRegion(Block &block);
251 
252   /// Returns 'op' if 'op' lies in this region, or otherwise finds the
253   /// ancestor of 'op' that lies in this region. Returns nullptr if the
254   /// latter fails.
255   Operation *findAncestorOpInRegion(Operation &op);
256 
257   /// Drop all operand uses from operations within this region, which is
258   /// an essential step in breaking cyclic dependences between references when
259   /// they are to be deleted.
260   void dropAllReferences();
261 
262   //===--------------------------------------------------------------------===//
263   // Walkers
264   //===--------------------------------------------------------------------===//
265 
266   /// Walk all nested operations, blocks or regions (including this region),
267   /// depending on the type of callback.
268   ///
269   /// The order in which operations, blocks or regions at the same nesting
270   /// level are visited (e.g., lexicographical or reverse lexicographical order)
271   /// is determined by `Iterator`. The walk order for enclosing operations,
272   /// blocks or regions with respect to their nested ones is specified by
273   /// `Order` (post-order by default).
274   ///
275   /// A callback on a operation or block is allowed to erase that operation or
276   /// block if either:
277   ///   * the walk is in post-order, or
278   ///   * the walk is in pre-order and the walk is skipped after the erasure.
279   ///
280   /// See Operation::walk for more details.
281   template <WalkOrder Order = WalkOrder::PostOrder,
282             typename Iterator = ForwardIterator, typename FnT,
283             typename ArgT = detail::first_argument<FnT>,
284             typename RetT = detail::walkResultType<FnT>>
285   RetT walk(FnT &&callback) {
286     if constexpr (std::is_same<ArgT, Region *>::value &&
287                   Order == WalkOrder::PreOrder) {
288       // Pre-order walk on regions: invoke the callback on this region.
289       if constexpr (std::is_same<RetT, void>::value) {
290         callback(this);
291       } else {
292         RetT result = callback(this);
293         if (result.wasSkipped())
294           return WalkResult::advance();
295         if (result.wasInterrupted())
296           return WalkResult::interrupt();
297       }
298     }
299 
300     // Walk nested operations, blocks or regions.
301     for (auto &block : *this) {
302       if constexpr (std::is_same<RetT, void>::value) {
303         block.walk<Order, Iterator>(callback);
304       } else {
305         if (block.walk<Order, Iterator>(callback).wasInterrupted())
306           return WalkResult::interrupt();
307       }
308     }
309 
310     if constexpr (std::is_same<ArgT, Region *>::value &&
311                   Order == WalkOrder::PostOrder) {
312       // Post-order walk on regions: invoke the callback on this block.
313       return callback(this);
314     }
315     if constexpr (!std::is_same<RetT, void>::value)
316       return WalkResult::advance();
317   }
318 
319   //===--------------------------------------------------------------------===//
320   // CFG view utilities
321   //===--------------------------------------------------------------------===//
322 
323   /// Displays the CFG in a window. This is for use from the debugger and
324   /// depends on Graphviz to generate the graph.
325   /// This function is defined in ViewOpGraph.cpp and only works with that
326   /// target linked.
327   void viewGraph(const Twine &regionName);
328   void viewGraph();
329 
330 private:
331   BlockListType blocks;
332 
333   /// This is the object we are part of.
334   Operation *container = nullptr;
335 };
336 
337 /// This class provides an abstraction over the different types of ranges over
338 /// Regions. In many cases, this prevents the need to explicitly materialize a
339 /// SmallVector/std::vector. This class should be used in places that are not
340 /// suitable for a more derived type (e.g. ArrayRef) or a template range
341 /// parameter.
342 class RegionRange
343     : public llvm::detail::indexed_accessor_range_base<
344           RegionRange,
345           PointerUnion<Region *, const std::unique_ptr<Region> *, Region **>,
346           Region *, Region *, Region *> {
347   /// The type representing the owner of this range. This is either an owning
348   /// list of regions, a list of region unique pointers, or a list of region
349   /// pointers.
350   using OwnerT =
351       PointerUnion<Region *, const std::unique_ptr<Region> *, Region **>;
352 
353 public:
354   using RangeBaseT::RangeBaseT;
355 
356   RegionRange(MutableArrayRef<Region> regions = std::nullopt);
357 
358   template <typename Arg, typename = std::enable_if_t<std::is_constructible<
359                               ArrayRef<std::unique_ptr<Region>>, Arg>::value>>
360   RegionRange(Arg &&arg LLVM_LIFETIME_BOUND)
361       : RegionRange(ArrayRef<std::unique_ptr<Region>>(std::forward<Arg>(arg))) {
362   }
363   template <typename Arg>
364   RegionRange(
365       Arg &&arg LLVM_LIFETIME_BOUND,
366       std::enable_if_t<std::is_constructible<ArrayRef<Region *>, Arg>::value>
367           * = nullptr)
368       : RegionRange(ArrayRef<Region *>(std::forward<Arg>(arg))) {}
369   RegionRange(ArrayRef<std::unique_ptr<Region>> regions);
370   RegionRange(ArrayRef<Region *> regions);
371 
372 private:
373   /// See `llvm::detail::indexed_accessor_range_base` for details.
374   static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
375   /// See `llvm::detail::indexed_accessor_range_base` for details.
376   static Region *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
377 
378   /// Allow access to `offset_base` and `dereference_iterator`.
379   friend RangeBaseT;
380 };
381 
382 } // namespace mlir
383 
384 #endif // MLIR_IR_REGION_H
385