xref: /llvm-project/mlir/include/mlir/IR/BlockSupport.h (revision 1297c1125f9c284e0cc0f2bf50d4b7ba519f7309)
1 //===- BlockSupport.h -------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a number of support types for the Block class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_BLOCKSUPPORT_H
14 #define MLIR_IR_BLOCKSUPPORT_H
15 
16 #include "mlir/IR/Value.h"
17 #include "llvm/ADT/PointerUnion.h"
18 #include "llvm/ADT/ilist.h"
19 #include "llvm/ADT/ilist_node.h"
20 
21 namespace mlir {
22 class Block;
23 
24 //===----------------------------------------------------------------------===//
25 // BlockOperand
26 //===----------------------------------------------------------------------===//
27 
28 /// A block operand represents an operand that holds a reference to a Block,
29 /// e.g. for terminator operations.
30 class BlockOperand : public IROperand<BlockOperand, Block *> {
31 public:
32   using IROperand<BlockOperand, Block *>::IROperand;
33 
34   /// Provide the use list that is attached to the given block.
35   static IRObjectWithUseList<BlockOperand> *getUseList(Block *value);
36 
37   /// Return which operand this is in the BlockOperand list of the Operation.
38   unsigned getOperandNumber();
39 };
40 
41 //===----------------------------------------------------------------------===//
42 // Predecessors
43 //===----------------------------------------------------------------------===//
44 
45 /// Implement a predecessor iterator for blocks. This works by walking the use
46 /// lists of the blocks. The entries on this list are the BlockOperands that
47 /// are embedded into terminator operations. From the operand, we can get the
48 /// terminator that contains it, and its parent block is the predecessor.
49 class PredecessorIterator final
50     : public llvm::mapped_iterator<ValueUseIterator<BlockOperand>,
51                                    Block *(*)(BlockOperand &)> {
52   static Block *unwrap(BlockOperand &value);
53 
54 public:
55   /// Initializes the operand type iterator to the specified operand iterator.
56   PredecessorIterator(ValueUseIterator<BlockOperand> it)
57       : llvm::mapped_iterator<ValueUseIterator<BlockOperand>,
58                               Block *(*)(BlockOperand &)>(it, &unwrap) {}
59   explicit PredecessorIterator(BlockOperand *operand)
60       : PredecessorIterator(ValueUseIterator<BlockOperand>(operand)) {}
61 
62   /// Get the successor number in the predecessor terminator.
63   unsigned getSuccessorIndex() const;
64 };
65 
66 //===----------------------------------------------------------------------===//
67 // Successors
68 //===----------------------------------------------------------------------===//
69 
70 /// This class implements the successor iterators for Block.
71 class SuccessorRange final
72     : public llvm::detail::indexed_accessor_range_base<
73           SuccessorRange, BlockOperand *, Block *, Block *, Block *> {
74 public:
75   using RangeBaseT::RangeBaseT;
76   SuccessorRange();
77   SuccessorRange(Block *block);
78   SuccessorRange(Operation *term);
79 
80 private:
81   /// See `llvm::detail::indexed_accessor_range_base` for details.
82   static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) {
83     return object + index;
84   }
85   /// See `llvm::detail::indexed_accessor_range_base` for details.
86   static Block *dereference_iterator(BlockOperand *object, ptrdiff_t index) {
87     return object[index].get();
88   }
89 
90   /// Allow access to `offset_base` and `dereference_iterator`.
91   friend RangeBaseT;
92 };
93 
94 //===----------------------------------------------------------------------===//
95 // BlockRange
96 //===----------------------------------------------------------------------===//
97 
98 /// This class provides an abstraction over the different types of ranges over
99 /// Blocks. In many cases, this prevents the need to explicitly materialize a
100 /// SmallVector/std::vector. This class should be used in places that are not
101 /// suitable for a more derived type (e.g. ArrayRef) or a template range
102 /// parameter.
103 class BlockRange final
104     : public llvm::detail::indexed_accessor_range_base<
105           BlockRange, llvm::PointerUnion<BlockOperand *, Block *const *>,
106           Block *, Block *, Block *> {
107 public:
108   using RangeBaseT::RangeBaseT;
109   BlockRange(ArrayRef<Block *> blocks = std::nullopt);
110   BlockRange(SuccessorRange successors);
111   template <typename Arg, typename = std::enable_if_t<std::is_constructible<
112                               ArrayRef<Block *>, Arg>::value>>
113   BlockRange(Arg &&arg LLVM_LIFETIME_BOUND)
114       : BlockRange(ArrayRef<Block *>(std::forward<Arg>(arg))) {}
115   BlockRange(std::initializer_list<Block *> blocks LLVM_LIFETIME_BOUND)
116       : BlockRange(ArrayRef<Block *>(blocks)) {}
117 
118 private:
119   /// The owner of the range is either:
120   /// * A pointer to the first element of an array of block operands.
121   /// * A pointer to the first element of an array of Block *.
122   using OwnerT = llvm::PointerUnion<BlockOperand *, Block *const *>;
123 
124   /// See `llvm::detail::indexed_accessor_range_base` for details.
125   static OwnerT offset_base(OwnerT object, ptrdiff_t index);
126 
127   /// See `llvm::detail::indexed_accessor_range_base` for details.
128   static Block *dereference_iterator(OwnerT object, ptrdiff_t index);
129 
130   /// Allow access to `offset_base` and `dereference_iterator`.
131   friend RangeBaseT;
132 };
133 
134 //===----------------------------------------------------------------------===//
135 // Operation Iterators
136 //===----------------------------------------------------------------------===//
137 
138 namespace detail {
139 /// A utility iterator that filters out operations that are not 'OpT'.
140 template <typename OpT, typename IteratorT>
141 class op_filter_iterator
142     : public llvm::filter_iterator<IteratorT, bool (*)(Operation &)> {
143   static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
144 
145 public:
146   op_filter_iterator(IteratorT it, IteratorT end)
147       : llvm::filter_iterator<IteratorT, bool (*)(Operation &)>(it, end,
148                                                                 &filter) {}
149 
150   /// Allow implicit conversion to the underlying iterator.
151   operator const IteratorT &() const { return this->wrapped(); }
152 };
153 
154 /// This class provides iteration over the held operations of a block for a
155 /// specific operation type.
156 template <typename OpT, typename IteratorT>
157 class op_iterator
158     : public llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
159                                    OpT (*)(Operation &)> {
160   static OpT unwrap(Operation &op) { return cast<OpT>(op); }
161 
162 public:
163   /// Initializes the iterator to the specified filter iterator.
164   op_iterator(op_filter_iterator<OpT, IteratorT> it)
165       : llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
166                               OpT (*)(Operation &)>(it, &unwrap) {}
167 
168   /// Allow implicit conversion to the underlying block iterator.
169   operator const IteratorT &() const { return this->wrapped(); }
170 };
171 } // namespace detail
172 } // namespace mlir
173 
174 namespace llvm {
175 
176 /// Provide support for hashing successor ranges.
177 template <>
178 struct DenseMapInfo<mlir::SuccessorRange> {
179   static mlir::SuccessorRange getEmptyKey() {
180     auto *pointer = llvm::DenseMapInfo<mlir::BlockOperand *>::getEmptyKey();
181     return mlir::SuccessorRange(pointer, 0);
182   }
183   static mlir::SuccessorRange getTombstoneKey() {
184     auto *pointer = llvm::DenseMapInfo<mlir::BlockOperand *>::getTombstoneKey();
185     return mlir::SuccessorRange(pointer, 0);
186   }
187   static unsigned getHashValue(mlir::SuccessorRange value) {
188     return llvm::hash_combine_range(value.begin(), value.end());
189   }
190   static bool isEqual(mlir::SuccessorRange lhs, mlir::SuccessorRange rhs) {
191     if (rhs.getBase() == getEmptyKey().getBase())
192       return lhs.getBase() == getEmptyKey().getBase();
193     if (rhs.getBase() == getTombstoneKey().getBase())
194       return lhs.getBase() == getTombstoneKey().getBase();
195     return lhs == rhs;
196   }
197 };
198 
199 //===----------------------------------------------------------------------===//
200 // ilist_traits for Operation
201 //===----------------------------------------------------------------------===//
202 
203 namespace ilist_detail {
204 // Explicitly define the node access for the operation list so that we can
205 // break the dependence on the Operation class in this header. This allows for
206 // operations to have trailing Regions without a circular include
207 // dependence.
208 template <>
209 struct SpecificNodeAccess<
210     typename compute_node_options<::mlir::Operation>::type> : NodeAccess {
211 protected:
212   using OptionsT = typename compute_node_options<mlir::Operation>::type;
213   using pointer = typename OptionsT::pointer;
214   using const_pointer = typename OptionsT::const_pointer;
215   using node_type = ilist_node_impl<OptionsT>;
216 
217   static node_type *getNodePtr(pointer N);
218   static const node_type *getNodePtr(const_pointer N);
219 
220   static pointer getValuePtr(node_type *N);
221   static const_pointer getValuePtr(const node_type *N);
222 };
223 } // namespace ilist_detail
224 
225 template <>
226 struct ilist_traits<::mlir::Operation> {
227   using Operation = ::mlir::Operation;
228   using op_iterator = simple_ilist<Operation>::iterator;
229 
230   static void deleteNode(Operation *op);
231   void addNodeToList(Operation *op);
232   void removeNodeFromList(Operation *op);
233   void transferNodesFromList(ilist_traits<Operation> &otherList,
234                              op_iterator first, op_iterator last);
235 
236 private:
237   mlir::Block *getContainingBlock();
238 };
239 
240 //===----------------------------------------------------------------------===//
241 // ilist_traits for Block
242 //===----------------------------------------------------------------------===//
243 
244 template <>
245 struct ilist_traits<::mlir::Block> : public ilist_alloc_traits<::mlir::Block> {
246   using Block = ::mlir::Block;
247   using block_iterator = simple_ilist<::mlir::Block>::iterator;
248 
249   void addNodeToList(Block *block);
250   void removeNodeFromList(Block *block);
251   void transferNodesFromList(ilist_traits<Block> &otherList,
252                              block_iterator first, block_iterator last);
253 
254 private:
255   mlir::Region *getParentRegion();
256 };
257 
258 } // namespace llvm
259 
260 #endif // MLIR_IR_BLOCKSUPPORT_H
261