1 //===- BlockSupport.h -------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a number of support types for the Block class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_BLOCKSUPPORT_H 14 #define MLIR_IR_BLOCKSUPPORT_H 15 16 #include "mlir/IR/Value.h" 17 #include "llvm/ADT/PointerUnion.h" 18 #include "llvm/ADT/ilist.h" 19 #include "llvm/ADT/ilist_node.h" 20 21 namespace mlir { 22 class Block; 23 24 //===----------------------------------------------------------------------===// 25 // BlockOperand 26 //===----------------------------------------------------------------------===// 27 28 /// A block operand represents an operand that holds a reference to a Block, 29 /// e.g. for terminator operations. 30 class BlockOperand : public IROperand<BlockOperand, Block *> { 31 public: 32 using IROperand<BlockOperand, Block *>::IROperand; 33 34 /// Provide the use list that is attached to the given block. 35 static IRObjectWithUseList<BlockOperand> *getUseList(Block *value); 36 37 /// Return which operand this is in the BlockOperand list of the Operation. 38 unsigned getOperandNumber(); 39 }; 40 41 //===----------------------------------------------------------------------===// 42 // Predecessors 43 //===----------------------------------------------------------------------===// 44 45 /// Implement a predecessor iterator for blocks. This works by walking the use 46 /// lists of the blocks. The entries on this list are the BlockOperands that 47 /// are embedded into terminator operations. From the operand, we can get the 48 /// terminator that contains it, and its parent block is the predecessor. 49 class PredecessorIterator final 50 : public llvm::mapped_iterator<ValueUseIterator<BlockOperand>, 51 Block *(*)(BlockOperand &)> { 52 static Block *unwrap(BlockOperand &value); 53 54 public: 55 /// Initializes the operand type iterator to the specified operand iterator. 56 PredecessorIterator(ValueUseIterator<BlockOperand> it) 57 : llvm::mapped_iterator<ValueUseIterator<BlockOperand>, 58 Block *(*)(BlockOperand &)>(it, &unwrap) {} 59 explicit PredecessorIterator(BlockOperand *operand) 60 : PredecessorIterator(ValueUseIterator<BlockOperand>(operand)) {} 61 62 /// Get the successor number in the predecessor terminator. 63 unsigned getSuccessorIndex() const; 64 }; 65 66 //===----------------------------------------------------------------------===// 67 // Successors 68 //===----------------------------------------------------------------------===// 69 70 /// This class implements the successor iterators for Block. 71 class SuccessorRange final 72 : public llvm::detail::indexed_accessor_range_base< 73 SuccessorRange, BlockOperand *, Block *, Block *, Block *> { 74 public: 75 using RangeBaseT::RangeBaseT; 76 SuccessorRange(); 77 SuccessorRange(Block *block); 78 SuccessorRange(Operation *term); 79 80 private: 81 /// See `llvm::detail::indexed_accessor_range_base` for details. 82 static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) { 83 return object + index; 84 } 85 /// See `llvm::detail::indexed_accessor_range_base` for details. 86 static Block *dereference_iterator(BlockOperand *object, ptrdiff_t index) { 87 return object[index].get(); 88 } 89 90 /// Allow access to `offset_base` and `dereference_iterator`. 91 friend RangeBaseT; 92 }; 93 94 //===----------------------------------------------------------------------===// 95 // BlockRange 96 //===----------------------------------------------------------------------===// 97 98 /// This class provides an abstraction over the different types of ranges over 99 /// Blocks. In many cases, this prevents the need to explicitly materialize a 100 /// SmallVector/std::vector. This class should be used in places that are not 101 /// suitable for a more derived type (e.g. ArrayRef) or a template range 102 /// parameter. 103 class BlockRange final 104 : public llvm::detail::indexed_accessor_range_base< 105 BlockRange, llvm::PointerUnion<BlockOperand *, Block *const *>, 106 Block *, Block *, Block *> { 107 public: 108 using RangeBaseT::RangeBaseT; 109 BlockRange(ArrayRef<Block *> blocks = std::nullopt); 110 BlockRange(SuccessorRange successors); 111 template <typename Arg, typename = std::enable_if_t<std::is_constructible< 112 ArrayRef<Block *>, Arg>::value>> 113 BlockRange(Arg &&arg LLVM_LIFETIME_BOUND) 114 : BlockRange(ArrayRef<Block *>(std::forward<Arg>(arg))) {} 115 BlockRange(std::initializer_list<Block *> blocks LLVM_LIFETIME_BOUND) 116 : BlockRange(ArrayRef<Block *>(blocks)) {} 117 118 private: 119 /// The owner of the range is either: 120 /// * A pointer to the first element of an array of block operands. 121 /// * A pointer to the first element of an array of Block *. 122 using OwnerT = llvm::PointerUnion<BlockOperand *, Block *const *>; 123 124 /// See `llvm::detail::indexed_accessor_range_base` for details. 125 static OwnerT offset_base(OwnerT object, ptrdiff_t index); 126 127 /// See `llvm::detail::indexed_accessor_range_base` for details. 128 static Block *dereference_iterator(OwnerT object, ptrdiff_t index); 129 130 /// Allow access to `offset_base` and `dereference_iterator`. 131 friend RangeBaseT; 132 }; 133 134 //===----------------------------------------------------------------------===// 135 // Operation Iterators 136 //===----------------------------------------------------------------------===// 137 138 namespace detail { 139 /// A utility iterator that filters out operations that are not 'OpT'. 140 template <typename OpT, typename IteratorT> 141 class op_filter_iterator 142 : public llvm::filter_iterator<IteratorT, bool (*)(Operation &)> { 143 static bool filter(Operation &op) { return llvm::isa<OpT>(op); } 144 145 public: 146 op_filter_iterator(IteratorT it, IteratorT end) 147 : llvm::filter_iterator<IteratorT, bool (*)(Operation &)>(it, end, 148 &filter) {} 149 150 /// Allow implicit conversion to the underlying iterator. 151 operator const IteratorT &() const { return this->wrapped(); } 152 }; 153 154 /// This class provides iteration over the held operations of a block for a 155 /// specific operation type. 156 template <typename OpT, typename IteratorT> 157 class op_iterator 158 : public llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>, 159 OpT (*)(Operation &)> { 160 static OpT unwrap(Operation &op) { return cast<OpT>(op); } 161 162 public: 163 /// Initializes the iterator to the specified filter iterator. 164 op_iterator(op_filter_iterator<OpT, IteratorT> it) 165 : llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>, 166 OpT (*)(Operation &)>(it, &unwrap) {} 167 168 /// Allow implicit conversion to the underlying block iterator. 169 operator const IteratorT &() const { return this->wrapped(); } 170 }; 171 } // namespace detail 172 } // namespace mlir 173 174 namespace llvm { 175 176 /// Provide support for hashing successor ranges. 177 template <> 178 struct DenseMapInfo<mlir::SuccessorRange> { 179 static mlir::SuccessorRange getEmptyKey() { 180 auto *pointer = llvm::DenseMapInfo<mlir::BlockOperand *>::getEmptyKey(); 181 return mlir::SuccessorRange(pointer, 0); 182 } 183 static mlir::SuccessorRange getTombstoneKey() { 184 auto *pointer = llvm::DenseMapInfo<mlir::BlockOperand *>::getTombstoneKey(); 185 return mlir::SuccessorRange(pointer, 0); 186 } 187 static unsigned getHashValue(mlir::SuccessorRange value) { 188 return llvm::hash_combine_range(value.begin(), value.end()); 189 } 190 static bool isEqual(mlir::SuccessorRange lhs, mlir::SuccessorRange rhs) { 191 if (rhs.getBase() == getEmptyKey().getBase()) 192 return lhs.getBase() == getEmptyKey().getBase(); 193 if (rhs.getBase() == getTombstoneKey().getBase()) 194 return lhs.getBase() == getTombstoneKey().getBase(); 195 return lhs == rhs; 196 } 197 }; 198 199 //===----------------------------------------------------------------------===// 200 // ilist_traits for Operation 201 //===----------------------------------------------------------------------===// 202 203 namespace ilist_detail { 204 // Explicitly define the node access for the operation list so that we can 205 // break the dependence on the Operation class in this header. This allows for 206 // operations to have trailing Regions without a circular include 207 // dependence. 208 template <> 209 struct SpecificNodeAccess< 210 typename compute_node_options<::mlir::Operation>::type> : NodeAccess { 211 protected: 212 using OptionsT = typename compute_node_options<mlir::Operation>::type; 213 using pointer = typename OptionsT::pointer; 214 using const_pointer = typename OptionsT::const_pointer; 215 using node_type = ilist_node_impl<OptionsT>; 216 217 static node_type *getNodePtr(pointer N); 218 static const node_type *getNodePtr(const_pointer N); 219 220 static pointer getValuePtr(node_type *N); 221 static const_pointer getValuePtr(const node_type *N); 222 }; 223 } // namespace ilist_detail 224 225 template <> 226 struct ilist_traits<::mlir::Operation> { 227 using Operation = ::mlir::Operation; 228 using op_iterator = simple_ilist<Operation>::iterator; 229 230 static void deleteNode(Operation *op); 231 void addNodeToList(Operation *op); 232 void removeNodeFromList(Operation *op); 233 void transferNodesFromList(ilist_traits<Operation> &otherList, 234 op_iterator first, op_iterator last); 235 236 private: 237 mlir::Block *getContainingBlock(); 238 }; 239 240 //===----------------------------------------------------------------------===// 241 // ilist_traits for Block 242 //===----------------------------------------------------------------------===// 243 244 template <> 245 struct ilist_traits<::mlir::Block> : public ilist_alloc_traits<::mlir::Block> { 246 using Block = ::mlir::Block; 247 using block_iterator = simple_ilist<::mlir::Block>::iterator; 248 249 void addNodeToList(Block *block); 250 void removeNodeFromList(Block *block); 251 void transferNodesFromList(ilist_traits<Block> &otherList, 252 block_iterator first, block_iterator last); 253 254 private: 255 mlir::Region *getParentRegion(); 256 }; 257 258 } // namespace llvm 259 260 #endif // MLIR_IR_BLOCKSUPPORT_H 261