1 //===- ValueRange.h - Indexed Value-Iterators Range Classes -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the ValueRange related classes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_VALUERANGE_H 14 #define MLIR_IR_VALUERANGE_H 15 16 #include "mlir/IR/BuiltinAttributes.h" 17 #include "mlir/IR/Types.h" 18 #include "mlir/IR/Value.h" 19 #include "llvm/ADT/PointerUnion.h" 20 #include "llvm/ADT/Sequence.h" 21 #include <optional> 22 23 namespace mlir { 24 class ValueRange; 25 template <typename ValueRangeT> 26 class ValueTypeRange; 27 class TypeRangeRange; 28 template <typename ValueIteratorT> 29 class ValueTypeIterator; 30 class OperandRangeRange; 31 class MutableOperandRangeRange; 32 33 //===----------------------------------------------------------------------===// 34 // Operation Value-Iterators 35 //===----------------------------------------------------------------------===// 36 37 //===----------------------------------------------------------------------===// 38 // OperandRange 39 40 /// This class implements the operand iterators for the Operation class. 41 class OperandRange final : public llvm::detail::indexed_accessor_range_base< 42 OperandRange, OpOperand *, Value, Value, Value> { 43 public: 44 using RangeBaseT::RangeBaseT; 45 46 /// Returns the types of the values within this range. 47 using type_iterator = ValueTypeIterator<iterator>; 48 using type_range = ValueTypeRange<OperandRange>; 49 type_range getTypes() const; 50 type_range getType() const; 51 52 /// Return the operand index of the first element of this range. The range 53 /// must not be empty. 54 unsigned getBeginOperandIndex() const; 55 56 /// Split this range into a set of contiguous subranges using the given 57 /// elements attribute, which contains the sizes of the sub ranges. 58 OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const; 59 60 private: 61 /// See `llvm::detail::indexed_accessor_range_base` for details. 62 static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) { 63 return object + index; 64 } 65 /// See `llvm::detail::indexed_accessor_range_base` for details. 66 static Value dereference_iterator(OpOperand *object, ptrdiff_t index) { 67 return object[index].get(); 68 } 69 70 /// Allow access to `offset_base` and `dereference_iterator`. 71 friend RangeBaseT; 72 }; 73 74 //===----------------------------------------------------------------------===// 75 // OperandRangeRange 76 77 /// This class represents a contiguous range of operand ranges, e.g. from a 78 /// VariadicOfVariadic operand group. 79 class OperandRangeRange final 80 : public llvm::indexed_accessor_range< 81 OperandRangeRange, std::pair<OpOperand *, Attribute>, OperandRange, 82 OperandRange, OperandRange> { 83 using OwnerT = std::pair<OpOperand *, Attribute>; 84 using RangeBaseT = 85 llvm::indexed_accessor_range<OperandRangeRange, OwnerT, OperandRange, 86 OperandRange, OperandRange>; 87 88 public: 89 using RangeBaseT::RangeBaseT; 90 91 /// Returns the range of types of the values within this range. 92 TypeRangeRange getTypes() const; 93 TypeRangeRange getType() const; 94 95 /// Construct a range given a parent set of operands, and an I32 elements 96 /// attribute containing the sizes of the sub ranges. 97 OperandRangeRange(OperandRange operands, Attribute operandSegments); 98 99 /// Flatten all of the sub ranges into a single contiguous operand range. 100 OperandRange join() const; 101 102 private: 103 /// See `llvm::indexed_accessor_range` for details. 104 static OperandRange dereference(const OwnerT &object, ptrdiff_t index); 105 106 /// Allow access to `dereference_iterator`. 107 friend RangeBaseT; 108 }; 109 110 //===----------------------------------------------------------------------===// 111 // MutableOperandRange 112 113 /// This class provides a mutable adaptor for a range of operands. It allows for 114 /// setting, inserting, and erasing operands from the given range. 115 class MutableOperandRange { 116 public: 117 /// A pair of a named attribute corresponding to an operand segment attribute, 118 /// and the index within that attribute. The attribute should correspond to a 119 /// dense i32 array attr. 120 using OperandSegment = std::pair<unsigned, NamedAttribute>; 121 122 /// Construct a new mutable range from the given operand, operand start index, 123 /// and range length. `operandSegments` is an optional set of operand segments 124 /// to be updated when mutating the operand list. 125 MutableOperandRange(Operation *owner, unsigned start, unsigned length, 126 ArrayRef<OperandSegment> operandSegments = std::nullopt); 127 MutableOperandRange(Operation *owner); 128 129 /// Construct a new mutable range for the given OpOperand. 130 MutableOperandRange(OpOperand &opOperand); 131 132 /// Slice this range into a sub range, with the additional operand segment. 133 MutableOperandRange 134 slice(unsigned subStart, unsigned subLen, 135 std::optional<OperandSegment> segment = std::nullopt) const; 136 137 /// Append the given values to the range. 138 void append(ValueRange values); 139 140 /// Assign this range to the given values. 141 void assign(ValueRange values); 142 143 /// Assign the range to the given value. 144 void assign(Value value); 145 146 /// Erase the operands within the given sub-range. 147 void erase(unsigned subStart, unsigned subLen = 1); 148 149 /// Clear this range and erase all of the operands. 150 void clear(); 151 152 /// Returns the current size of the range. 153 unsigned size() const { return length; } 154 155 /// Returns if the current range is empty. 156 bool empty() const { return size() == 0; } 157 158 /// Explicit conversion to an OperandRange. 159 OperandRange getAsOperandRange() const; 160 161 /// Allow implicit conversion to an OperandRange. 162 operator OperandRange() const; 163 164 /// Allow implicit conversion to a MutableArrayRef. 165 operator MutableArrayRef<OpOperand>() const; 166 167 /// Returns the owning operation. 168 Operation *getOwner() const { return owner; } 169 170 /// Split this range into a set of contiguous subranges using the given 171 /// elements attribute, which contains the sizes of the sub ranges. 172 MutableOperandRangeRange split(NamedAttribute segmentSizes) const; 173 174 /// Returns the OpOperand at the given index. 175 OpOperand &operator[](unsigned index) const; 176 177 /// Iterators enumerate OpOperands. 178 MutableArrayRef<OpOperand>::iterator begin() const; 179 MutableArrayRef<OpOperand>::iterator end() const; 180 181 private: 182 /// Update the length of this range to the one provided. 183 void updateLength(unsigned newLength); 184 185 /// The owning operation of this range. 186 Operation *owner; 187 188 /// The start index of the operand range within the owner operand list, and 189 /// the length starting from `start`. 190 unsigned start, length; 191 192 /// Optional set of operand segments that should be updated when mutating the 193 /// length of this range. 194 SmallVector<OperandSegment, 1> operandSegments; 195 }; 196 197 //===----------------------------------------------------------------------===// 198 // MutableOperandRangeRange 199 200 /// This class represents a contiguous range of mutable operand ranges, e.g. 201 /// from a VariadicOfVariadic operand group. 202 class MutableOperandRangeRange final 203 : public llvm::indexed_accessor_range< 204 MutableOperandRangeRange, 205 std::pair<MutableOperandRange, NamedAttribute>, MutableOperandRange, 206 MutableOperandRange, MutableOperandRange> { 207 using OwnerT = std::pair<MutableOperandRange, NamedAttribute>; 208 using RangeBaseT = 209 llvm::indexed_accessor_range<MutableOperandRangeRange, OwnerT, 210 MutableOperandRange, MutableOperandRange, 211 MutableOperandRange>; 212 213 public: 214 using RangeBaseT::RangeBaseT; 215 216 /// Construct a range given a parent set of operands, and an I32 tensor 217 /// elements attribute containing the sizes of the sub ranges. 218 MutableOperandRangeRange(const MutableOperandRange &operands, 219 NamedAttribute operandSegmentAttr); 220 221 /// Flatten all of the sub ranges into a single contiguous mutable operand 222 /// range. 223 MutableOperandRange join() const; 224 225 /// Allow implicit conversion to an OperandRangeRange. 226 operator OperandRangeRange() const; 227 228 private: 229 /// See `llvm::indexed_accessor_range` for details. 230 static MutableOperandRange dereference(const OwnerT &object, ptrdiff_t index); 231 232 /// Allow access to `dereference_iterator`. 233 friend RangeBaseT; 234 }; 235 236 //===----------------------------------------------------------------------===// 237 // ResultRange 238 239 /// This class implements the result iterators for the Operation class. 240 class ResultRange final 241 : public llvm::detail::indexed_accessor_range_base< 242 ResultRange, detail::OpResultImpl *, OpResult, OpResult, OpResult> { 243 public: 244 using RangeBaseT::RangeBaseT; 245 ResultRange(OpResult result); 246 247 //===--------------------------------------------------------------------===// 248 // Types 249 //===--------------------------------------------------------------------===// 250 251 /// Returns the types of the values within this range. 252 using type_iterator = ValueTypeIterator<iterator>; 253 using type_range = ValueTypeRange<ResultRange>; 254 type_range getTypes() const; 255 type_range getType() const; 256 257 //===--------------------------------------------------------------------===// 258 // Uses 259 //===--------------------------------------------------------------------===// 260 261 class UseIterator; 262 using use_iterator = UseIterator; 263 using use_range = iterator_range<use_iterator>; 264 265 /// Returns a range of all uses of results within this range, which is useful 266 /// for iterating over all uses. 267 use_range getUses() const; 268 use_iterator use_begin() const; 269 use_iterator use_end() const; 270 271 /// Returns true if no results in this range have uses. 272 bool use_empty() const { 273 return llvm::all_of(*this, 274 [](OpResult result) { return result.use_empty(); }); 275 } 276 277 /// Replace all uses of results of this range with the provided 'values'. The 278 /// size of `values` must match the size of this range. 279 template <typename ValuesT> 280 std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value> 281 replaceAllUsesWith(ValuesT &&values) { 282 assert(static_cast<size_t>(std::distance(values.begin(), values.end())) == 283 size() && 284 "expected 'values' to correspond 1-1 with the number of results"); 285 286 for (auto it : llvm::zip(*this, values)) 287 std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); 288 } 289 290 /// Replace all uses of results of this range with results of 'op'. 291 void replaceAllUsesWith(Operation *op); 292 293 /// Replace uses of results of this range with the provided 'values' if the 294 /// given callback returns true. The size of `values` must match the size of 295 /// this range. 296 template <typename ValuesT> 297 std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value> 298 replaceUsesWithIf(ValuesT &&values, 299 function_ref<bool(OpOperand &)> shouldReplace) { 300 assert(static_cast<size_t>(std::distance(values.begin(), values.end())) == 301 size() && 302 "expected 'values' to correspond 1-1 with the number of results"); 303 304 for (auto it : llvm::zip(*this, values)) 305 std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace); 306 } 307 308 /// Replace uses of results of this range with results of `op` if the given 309 /// callback returns true. 310 void replaceUsesWithIf(Operation *op, 311 function_ref<bool(OpOperand &)> shouldReplace); 312 313 //===--------------------------------------------------------------------===// 314 // Users 315 //===--------------------------------------------------------------------===// 316 317 using user_iterator = ValueUserIterator<use_iterator, OpOperand>; 318 using user_range = iterator_range<user_iterator>; 319 320 /// Returns a range of all users. 321 user_range getUsers(); 322 user_iterator user_begin(); 323 user_iterator user_end(); 324 325 private: 326 /// See `llvm::detail::indexed_accessor_range_base` for details. 327 static detail::OpResultImpl *offset_base(detail::OpResultImpl *object, 328 ptrdiff_t index) { 329 return object->getNextResultAtOffset(index); 330 } 331 /// See `llvm::detail::indexed_accessor_range_base` for details. 332 static OpResult dereference_iterator(detail::OpResultImpl *object, 333 ptrdiff_t index) { 334 return offset_base(object, index); 335 } 336 337 /// Allow access to `offset_base` and `dereference_iterator`. 338 friend RangeBaseT; 339 }; 340 341 /// This class implements a use iterator for a range of operation results. 342 /// This iterates over all uses of all results within the given result range. 343 class ResultRange::UseIterator final 344 : public llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag, 345 OpOperand> { 346 public: 347 /// Initialize the UseIterator. Specify `end` to return iterator to last 348 /// use, otherwise this is an iterator to the first use. 349 explicit UseIterator(ResultRange results, bool end = false); 350 351 using llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag, 352 OpOperand>::operator++; 353 UseIterator &operator++(); 354 OpOperand *operator->() const { return use.getOperand(); } 355 OpOperand &operator*() const { return *use.getOperand(); } 356 357 bool operator==(const UseIterator &rhs) const { return use == rhs.use; } 358 bool operator!=(const UseIterator &rhs) const { return !(*this == rhs); } 359 360 private: 361 void skipOverResultsWithNoUsers(); 362 363 /// The range of results being iterated over. 364 ResultRange::iterator it, endIt; 365 /// The use of the result. 366 Value::use_iterator use; 367 }; 368 369 //===----------------------------------------------------------------------===// 370 // ValueRange 371 372 /// This class provides an abstraction over the different types of ranges over 373 /// Values. In many cases, this prevents the need to explicitly materialize a 374 /// SmallVector/std::vector. This class should be used in places that are not 375 /// suitable for a more derived type (e.g. ArrayRef) or a template range 376 /// parameter. 377 class ValueRange final 378 : public llvm::detail::indexed_accessor_range_base< 379 ValueRange, 380 PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>, 381 Value, Value, Value> { 382 public: 383 /// The type representing the owner of a ValueRange. This is either a list of 384 /// values, operands, or results. 385 using OwnerT = 386 PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>; 387 388 using RangeBaseT::RangeBaseT; 389 390 template <typename Arg, 391 typename = std::enable_if_t< 392 std::is_constructible<ArrayRef<Value>, Arg>::value && 393 !std::is_convertible<Arg, Value>::value>> 394 ValueRange(Arg &&arg LLVM_LIFETIME_BOUND) 395 : ValueRange(ArrayRef<Value>(std::forward<Arg>(arg))) {} 396 ValueRange(const Value &value LLVM_LIFETIME_BOUND) 397 : ValueRange(&value, /*count=*/1) {} 398 ValueRange(const std::initializer_list<Value> &values LLVM_LIFETIME_BOUND) 399 : ValueRange(ArrayRef<Value>(values)) {} 400 ValueRange(iterator_range<OperandRange::iterator> values) 401 : ValueRange(OperandRange(values)) {} 402 ValueRange(iterator_range<ResultRange::iterator> values) 403 : ValueRange(ResultRange(values)) {} 404 ValueRange(ArrayRef<BlockArgument> values) 405 : ValueRange(ArrayRef<Value>(values.data(), values.size())) {} 406 ValueRange(ArrayRef<Value> values = std::nullopt); 407 ValueRange(OperandRange values); 408 ValueRange(ResultRange values); 409 410 /// Returns the types of the values within this range. 411 using type_iterator = ValueTypeIterator<iterator>; 412 using type_range = ValueTypeRange<ValueRange>; 413 type_range getTypes() const; 414 type_range getType() const; 415 416 private: 417 /// See `llvm::detail::indexed_accessor_range_base` for details. 418 static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index); 419 /// See `llvm::detail::indexed_accessor_range_base` for details. 420 static Value dereference_iterator(const OwnerT &owner, ptrdiff_t index); 421 422 /// Allow access to `offset_base` and `dereference_iterator`. 423 friend RangeBaseT; 424 }; 425 426 } // namespace mlir 427 428 #endif // MLIR_IR_VALUERANGE_H 429