xref: /llvm-project/mlir/include/mlir/IR/ValueRange.h (revision 1297c1125f9c284e0cc0f2bf50d4b7ba519f7309)
1 //===- ValueRange.h - Indexed Value-Iterators Range Classes -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the ValueRange related classes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_VALUERANGE_H
14 #define MLIR_IR_VALUERANGE_H
15 
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Types.h"
18 #include "mlir/IR/Value.h"
19 #include "llvm/ADT/PointerUnion.h"
20 #include "llvm/ADT/Sequence.h"
21 #include <optional>
22 
23 namespace mlir {
24 class ValueRange;
25 template <typename ValueRangeT>
26 class ValueTypeRange;
27 class TypeRangeRange;
28 template <typename ValueIteratorT>
29 class ValueTypeIterator;
30 class OperandRangeRange;
31 class MutableOperandRangeRange;
32 
33 //===----------------------------------------------------------------------===//
34 // Operation Value-Iterators
35 //===----------------------------------------------------------------------===//
36 
37 //===----------------------------------------------------------------------===//
38 // OperandRange
39 
40 /// This class implements the operand iterators for the Operation class.
41 class OperandRange final : public llvm::detail::indexed_accessor_range_base<
42                                OperandRange, OpOperand *, Value, Value, Value> {
43 public:
44   using RangeBaseT::RangeBaseT;
45 
46   /// Returns the types of the values within this range.
47   using type_iterator = ValueTypeIterator<iterator>;
48   using type_range = ValueTypeRange<OperandRange>;
49   type_range getTypes() const;
50   type_range getType() const;
51 
52   /// Return the operand index of the first element of this range. The range
53   /// must not be empty.
54   unsigned getBeginOperandIndex() const;
55 
56   /// Split this range into a set of contiguous subranges using the given
57   /// elements attribute, which contains the sizes of the sub ranges.
58   OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const;
59 
60 private:
61   /// See `llvm::detail::indexed_accessor_range_base` for details.
62   static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
63     return object + index;
64   }
65   /// See `llvm::detail::indexed_accessor_range_base` for details.
66   static Value dereference_iterator(OpOperand *object, ptrdiff_t index) {
67     return object[index].get();
68   }
69 
70   /// Allow access to `offset_base` and `dereference_iterator`.
71   friend RangeBaseT;
72 };
73 
74 //===----------------------------------------------------------------------===//
75 // OperandRangeRange
76 
77 /// This class represents a contiguous range of operand ranges, e.g. from a
78 /// VariadicOfVariadic operand group.
79 class OperandRangeRange final
80     : public llvm::indexed_accessor_range<
81           OperandRangeRange, std::pair<OpOperand *, Attribute>, OperandRange,
82           OperandRange, OperandRange> {
83   using OwnerT = std::pair<OpOperand *, Attribute>;
84   using RangeBaseT =
85       llvm::indexed_accessor_range<OperandRangeRange, OwnerT, OperandRange,
86                                    OperandRange, OperandRange>;
87 
88 public:
89   using RangeBaseT::RangeBaseT;
90 
91   /// Returns the range of types of the values within this range.
92   TypeRangeRange getTypes() const;
93   TypeRangeRange getType() const;
94 
95   /// Construct a range given a parent set of operands, and an I32 elements
96   /// attribute containing the sizes of the sub ranges.
97   OperandRangeRange(OperandRange operands, Attribute operandSegments);
98 
99   /// Flatten all of the sub ranges into a single contiguous operand range.
100   OperandRange join() const;
101 
102 private:
103   /// See `llvm::indexed_accessor_range` for details.
104   static OperandRange dereference(const OwnerT &object, ptrdiff_t index);
105 
106   /// Allow access to `dereference_iterator`.
107   friend RangeBaseT;
108 };
109 
110 //===----------------------------------------------------------------------===//
111 // MutableOperandRange
112 
113 /// This class provides a mutable adaptor for a range of operands. It allows for
114 /// setting, inserting, and erasing operands from the given range.
115 class MutableOperandRange {
116 public:
117   /// A pair of a named attribute corresponding to an operand segment attribute,
118   /// and the index within that attribute. The attribute should correspond to a
119   /// dense i32 array attr.
120   using OperandSegment = std::pair<unsigned, NamedAttribute>;
121 
122   /// Construct a new mutable range from the given operand, operand start index,
123   /// and range length. `operandSegments` is an optional set of operand segments
124   /// to be updated when mutating the operand list.
125   MutableOperandRange(Operation *owner, unsigned start, unsigned length,
126                       ArrayRef<OperandSegment> operandSegments = std::nullopt);
127   MutableOperandRange(Operation *owner);
128 
129   /// Construct a new mutable range for the given OpOperand.
130   MutableOperandRange(OpOperand &opOperand);
131 
132   /// Slice this range into a sub range, with the additional operand segment.
133   MutableOperandRange
134   slice(unsigned subStart, unsigned subLen,
135         std::optional<OperandSegment> segment = std::nullopt) const;
136 
137   /// Append the given values to the range.
138   void append(ValueRange values);
139 
140   /// Assign this range to the given values.
141   void assign(ValueRange values);
142 
143   /// Assign the range to the given value.
144   void assign(Value value);
145 
146   /// Erase the operands within the given sub-range.
147   void erase(unsigned subStart, unsigned subLen = 1);
148 
149   /// Clear this range and erase all of the operands.
150   void clear();
151 
152   /// Returns the current size of the range.
153   unsigned size() const { return length; }
154 
155   /// Returns if the current range is empty.
156   bool empty() const { return size() == 0; }
157 
158   /// Explicit conversion to an OperandRange.
159   OperandRange getAsOperandRange() const;
160 
161   /// Allow implicit conversion to an OperandRange.
162   operator OperandRange() const;
163 
164   /// Allow implicit conversion to a MutableArrayRef.
165   operator MutableArrayRef<OpOperand>() const;
166 
167   /// Returns the owning operation.
168   Operation *getOwner() const { return owner; }
169 
170   /// Split this range into a set of contiguous subranges using the given
171   /// elements attribute, which contains the sizes of the sub ranges.
172   MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
173 
174   /// Returns the OpOperand at the given index.
175   OpOperand &operator[](unsigned index) const;
176 
177   /// Iterators enumerate OpOperands.
178   MutableArrayRef<OpOperand>::iterator begin() const;
179   MutableArrayRef<OpOperand>::iterator end() const;
180 
181 private:
182   /// Update the length of this range to the one provided.
183   void updateLength(unsigned newLength);
184 
185   /// The owning operation of this range.
186   Operation *owner;
187 
188   /// The start index of the operand range within the owner operand list, and
189   /// the length starting from `start`.
190   unsigned start, length;
191 
192   /// Optional set of operand segments that should be updated when mutating the
193   /// length of this range.
194   SmallVector<OperandSegment, 1> operandSegments;
195 };
196 
197 //===----------------------------------------------------------------------===//
198 // MutableOperandRangeRange
199 
200 /// This class represents a contiguous range of mutable operand ranges, e.g.
201 /// from a VariadicOfVariadic operand group.
202 class MutableOperandRangeRange final
203     : public llvm::indexed_accessor_range<
204           MutableOperandRangeRange,
205           std::pair<MutableOperandRange, NamedAttribute>, MutableOperandRange,
206           MutableOperandRange, MutableOperandRange> {
207   using OwnerT = std::pair<MutableOperandRange, NamedAttribute>;
208   using RangeBaseT =
209       llvm::indexed_accessor_range<MutableOperandRangeRange, OwnerT,
210                                    MutableOperandRange, MutableOperandRange,
211                                    MutableOperandRange>;
212 
213 public:
214   using RangeBaseT::RangeBaseT;
215 
216   /// Construct a range given a parent set of operands, and an I32 tensor
217   /// elements attribute containing the sizes of the sub ranges.
218   MutableOperandRangeRange(const MutableOperandRange &operands,
219                            NamedAttribute operandSegmentAttr);
220 
221   /// Flatten all of the sub ranges into a single contiguous mutable operand
222   /// range.
223   MutableOperandRange join() const;
224 
225   /// Allow implicit conversion to an OperandRangeRange.
226   operator OperandRangeRange() const;
227 
228 private:
229   /// See `llvm::indexed_accessor_range` for details.
230   static MutableOperandRange dereference(const OwnerT &object, ptrdiff_t index);
231 
232   /// Allow access to `dereference_iterator`.
233   friend RangeBaseT;
234 };
235 
236 //===----------------------------------------------------------------------===//
237 // ResultRange
238 
239 /// This class implements the result iterators for the Operation class.
240 class ResultRange final
241     : public llvm::detail::indexed_accessor_range_base<
242           ResultRange, detail::OpResultImpl *, OpResult, OpResult, OpResult> {
243 public:
244   using RangeBaseT::RangeBaseT;
245   ResultRange(OpResult result);
246 
247   //===--------------------------------------------------------------------===//
248   // Types
249   //===--------------------------------------------------------------------===//
250 
251   /// Returns the types of the values within this range.
252   using type_iterator = ValueTypeIterator<iterator>;
253   using type_range = ValueTypeRange<ResultRange>;
254   type_range getTypes() const;
255   type_range getType() const;
256 
257   //===--------------------------------------------------------------------===//
258   // Uses
259   //===--------------------------------------------------------------------===//
260 
261   class UseIterator;
262   using use_iterator = UseIterator;
263   using use_range = iterator_range<use_iterator>;
264 
265   /// Returns a range of all uses of results within this range, which is useful
266   /// for iterating over all uses.
267   use_range getUses() const;
268   use_iterator use_begin() const;
269   use_iterator use_end() const;
270 
271   /// Returns true if no results in this range have uses.
272   bool use_empty() const {
273     return llvm::all_of(*this,
274                         [](OpResult result) { return result.use_empty(); });
275   }
276 
277   /// Replace all uses of results of this range with the provided 'values'. The
278   /// size of `values` must match the size of this range.
279   template <typename ValuesT>
280   std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
281   replaceAllUsesWith(ValuesT &&values) {
282     assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
283                size() &&
284            "expected 'values' to correspond 1-1 with the number of results");
285 
286     for (auto it : llvm::zip(*this, values))
287       std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
288   }
289 
290   /// Replace all uses of results of this range with results of 'op'.
291   void replaceAllUsesWith(Operation *op);
292 
293   /// Replace uses of results of this range with the provided 'values' if the
294   /// given callback returns true. The size of `values` must match the size of
295   /// this range.
296   template <typename ValuesT>
297   std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
298   replaceUsesWithIf(ValuesT &&values,
299                     function_ref<bool(OpOperand &)> shouldReplace) {
300     assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
301                size() &&
302            "expected 'values' to correspond 1-1 with the number of results");
303 
304     for (auto it : llvm::zip(*this, values))
305       std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace);
306   }
307 
308   /// Replace uses of results of this range with results of `op` if the given
309   /// callback returns true.
310   void replaceUsesWithIf(Operation *op,
311                          function_ref<bool(OpOperand &)> shouldReplace);
312 
313   //===--------------------------------------------------------------------===//
314   // Users
315   //===--------------------------------------------------------------------===//
316 
317   using user_iterator = ValueUserIterator<use_iterator, OpOperand>;
318   using user_range = iterator_range<user_iterator>;
319 
320   /// Returns a range of all users.
321   user_range getUsers();
322   user_iterator user_begin();
323   user_iterator user_end();
324 
325 private:
326   /// See `llvm::detail::indexed_accessor_range_base` for details.
327   static detail::OpResultImpl *offset_base(detail::OpResultImpl *object,
328                                            ptrdiff_t index) {
329     return object->getNextResultAtOffset(index);
330   }
331   /// See `llvm::detail::indexed_accessor_range_base` for details.
332   static OpResult dereference_iterator(detail::OpResultImpl *object,
333                                        ptrdiff_t index) {
334     return offset_base(object, index);
335   }
336 
337   /// Allow access to `offset_base` and `dereference_iterator`.
338   friend RangeBaseT;
339 };
340 
341 /// This class implements a use iterator for a range of operation results.
342 /// This iterates over all uses of all results within the given result range.
343 class ResultRange::UseIterator final
344     : public llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag,
345                                         OpOperand> {
346 public:
347   /// Initialize the UseIterator. Specify `end` to return iterator to last
348   /// use, otherwise this is an iterator to the first use.
349   explicit UseIterator(ResultRange results, bool end = false);
350 
351   using llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag,
352                                    OpOperand>::operator++;
353   UseIterator &operator++();
354   OpOperand *operator->() const { return use.getOperand(); }
355   OpOperand &operator*() const { return *use.getOperand(); }
356 
357   bool operator==(const UseIterator &rhs) const { return use == rhs.use; }
358   bool operator!=(const UseIterator &rhs) const { return !(*this == rhs); }
359 
360 private:
361   void skipOverResultsWithNoUsers();
362 
363   /// The range of results being iterated over.
364   ResultRange::iterator it, endIt;
365   /// The use of the result.
366   Value::use_iterator use;
367 };
368 
369 //===----------------------------------------------------------------------===//
370 // ValueRange
371 
372 /// This class provides an abstraction over the different types of ranges over
373 /// Values. In many cases, this prevents the need to explicitly materialize a
374 /// SmallVector/std::vector. This class should be used in places that are not
375 /// suitable for a more derived type (e.g. ArrayRef) or a template range
376 /// parameter.
377 class ValueRange final
378     : public llvm::detail::indexed_accessor_range_base<
379           ValueRange,
380           PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
381           Value, Value, Value> {
382 public:
383   /// The type representing the owner of a ValueRange. This is either a list of
384   /// values, operands, or results.
385   using OwnerT =
386       PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
387 
388   using RangeBaseT::RangeBaseT;
389 
390   template <typename Arg,
391             typename = std::enable_if_t<
392                 std::is_constructible<ArrayRef<Value>, Arg>::value &&
393                 !std::is_convertible<Arg, Value>::value>>
394   ValueRange(Arg &&arg LLVM_LIFETIME_BOUND)
395       : ValueRange(ArrayRef<Value>(std::forward<Arg>(arg))) {}
396   ValueRange(const Value &value LLVM_LIFETIME_BOUND)
397       : ValueRange(&value, /*count=*/1) {}
398   ValueRange(const std::initializer_list<Value> &values LLVM_LIFETIME_BOUND)
399       : ValueRange(ArrayRef<Value>(values)) {}
400   ValueRange(iterator_range<OperandRange::iterator> values)
401       : ValueRange(OperandRange(values)) {}
402   ValueRange(iterator_range<ResultRange::iterator> values)
403       : ValueRange(ResultRange(values)) {}
404   ValueRange(ArrayRef<BlockArgument> values)
405       : ValueRange(ArrayRef<Value>(values.data(), values.size())) {}
406   ValueRange(ArrayRef<Value> values = std::nullopt);
407   ValueRange(OperandRange values);
408   ValueRange(ResultRange values);
409 
410   /// Returns the types of the values within this range.
411   using type_iterator = ValueTypeIterator<iterator>;
412   using type_range = ValueTypeRange<ValueRange>;
413   type_range getTypes() const;
414   type_range getType() const;
415 
416 private:
417   /// See `llvm::detail::indexed_accessor_range_base` for details.
418   static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
419   /// See `llvm::detail::indexed_accessor_range_base` for details.
420   static Value dereference_iterator(const OwnerT &owner, ptrdiff_t index);
421 
422   /// Allow access to `offset_base` and `dereference_iterator`.
423   friend RangeBaseT;
424 };
425 
426 } // namespace mlir
427 
428 #endif // MLIR_IR_VALUERANGE_H
429