xref: /llvm-project/mlir/include/mlir/Reducer/ReductionNode.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- ReductionNode.h - Reduction Node Implementation ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the reduction nodes, which are used to keep track of the
// metadata for a specific generated variant within a reduction pass and are the
// building blocks of the reduction tree structure. A reduction tree is used to
// keep track of the different generated variants throughout a reduction pass in
// the MLIR Reduce tool.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_REDUCER_REDUCTIONNODE_H
18 #define MLIR_REDUCER_REDUCTIONNODE_H
19 
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Reducer/Tester.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ToolOutputFile.h"

#include <cassert>
#include <queue>
#include <utility>
#include <vector>
28 
29 namespace mlir {
30 
31 class ModuleOp;
32 class Region;
33 
/// Selects the strategy used to walk the reduction tree. Each mode is paired
/// with a ReductionNode::iterator<mode> specialization that implements it; the
/// precise behavior of each strategy lives in that specialization.
enum TraversalMode {
  SinglePath = 0,
  Backtrack = 1,
  MultiPath = 2,
};
37 
38 /// ReductionTreePass will build a reduction tree during module reduction and
39 /// the ReductionNode represents the vertex of the tree. A ReductionNode records
40 /// the information such as the reduced module, how this node is reduced from
41 /// the parent node, etc. This information will be used to construct a reduction
42 /// path to reduce the certain module.
43 class ReductionNode {
44 public:
45   template <TraversalMode mode>
46   class iterator;
47 
48   using Range = std::pair<int, int>;
49 
50   ReductionNode(ReductionNode *parent, const std::vector<Range> &range,
51                 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
52 
getParent()53   ReductionNode *getParent() const { return parent; }
54 
55   /// If the ReductionNode hasn't been tested the interestingness, it'll be the
56   /// same module as the one in the parent node. Otherwise, the returned module
57   /// will have been applied certain reduction strategies. Note that it's not
58   /// necessary to be an interesting case or a reduced module (has smaller size
59   /// than parent's).
getModule()60   ModuleOp getModule() const { return module.get(); }
61 
62   /// Return the region we're reducing.
getRegion()63   Region &getRegion() const { return *region; }
64 
65   /// Return the size of the module.
getSize()66   size_t getSize() const { return size; }
67 
68   /// Returns true if the module exhibits the interesting behavior.
isInteresting()69   Tester::Interestingness isInteresting() const { return interesting; }
70 
71   /// Return the range information that how this node is reduced from the parent
72   /// node.
getStartRanges()73   ArrayRef<Range> getStartRanges() const { return startRanges; }
74 
75   /// Return the range set we are using to generate variants.
getRanges()76   ArrayRef<Range> getRanges() const { return ranges; }
77 
78   /// Return the generated variants(the child nodes).
getVariants()79   ArrayRef<ReductionNode *> getVariants() const { return variants; }
80 
81   /// Split the ranges and generate new variants.
82   ArrayRef<ReductionNode *> generateNewVariants();
83 
84   /// Update the interestingness result from tester.
85   void update(std::pair<Tester::Interestingness, size_t> result);
86 
87   /// Each Reduction Node contains a copy of module for applying rewrite
88   /// patterns. In addition, we only apply rewrite patterns in a certain region.
89   /// In init(), we will duplicate the module from parent node and locate the
90   /// corresponding region.
91   LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
92 
93 private:
94   /// A custom BFS iterator. The difference between
95   /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
96   /// We may explore more neighbors at certain node if we didn't find interested
97   /// event. As a result, we defer pushing adjacent nodes until poping the last
98   /// visited node. The graph exploration strategy will be put in
99   /// getNeighbors().
100   ///
101   /// Subclass BaseIterator and implement traversal strategy in getNeighbors().
102   template <typename T>
103   class BaseIterator {
104   public:
BaseIterator(ReductionNode * node)105     BaseIterator(ReductionNode *node) { visitQueue.push(node); }
106     BaseIterator(const BaseIterator &) = default;
107     BaseIterator() = default;
108 
end()109     static BaseIterator end() { return BaseIterator(); }
110 
111     bool operator==(const BaseIterator &i) {
112       return visitQueue == i.visitQueue;
113     }
114     bool operator!=(const BaseIterator &i) { return !(*this == i); }
115 
116     BaseIterator &operator++() {
117       ReductionNode *top = visitQueue.front();
118       visitQueue.pop();
119       for (ReductionNode *node : getNeighbors(top))
120         visitQueue.push(node);
121       return *this;
122     }
123 
124     BaseIterator operator++(int) {
125       BaseIterator tmp = *this;
126       ++*this;
127       return tmp;
128     }
129 
130     ReductionNode &operator*() const { return *(visitQueue.front()); }
131     ReductionNode *operator->() const { return visitQueue.front(); }
132 
133   protected:
getNeighbors(ReductionNode * node)134     ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) {
135       return static_cast<T *>(this)->getNeighbors(node);
136     }
137 
138   private:
139     std::queue<ReductionNode *> visitQueue;
140   };
141 
142   /// This is a copy of module from parent node. All the reducer patterns will
143   /// be applied to this instance.
144   OwningOpRef<ModuleOp> module;
145 
146   /// The region of certain operation we're reducing in the module
147   Region *region = nullptr;
148 
149   /// The node we are reduced from. It means we will be in variants of parent
150   /// node.
151   ReductionNode *parent = nullptr;
152 
153   /// The size of module after applying the reducer patterns with range
154   /// constraints. This is only valid while the interestingness has been tested.
155   size_t size = 0;
156 
157   /// This is true if the module has been evaluated and it exhibits the
158   /// interesting behavior.
159   Tester::Interestingness interesting = Tester::Interestingness::Untested;
160 
161   /// `ranges` represents the selected subset of operations in the region. We
162   /// implicitly number each operation in the region and ReductionTreePass will
163   /// apply reducer patterns on the operation falls into the `ranges`. We will
164   /// generate new ReductionNode with subset of `ranges` to see if we can do
165   /// further reduction. we may split the element in the `ranges` so that we can
166   /// have more subset variants from `ranges`.
167   /// Note that after applying the reducer patterns the number of operation in
168   /// the region may have changed, we need to update the `ranges` after that.
169   std::vector<Range> ranges;
170 
171   /// `startRanges` records the ranges of operations selected from the parent
172   /// node to produce this ReductionNode. It can be used to construct the
173   /// reduction path from the root. I.e., if we apply the same reducer patterns
174   /// and `startRanges` selection on the parent region, we will get the same
175   /// module as this node.
176   const std::vector<Range> startRanges;
177 
178   /// This points to the child variants that were created using this node as a
179   /// starting point.
180   std::vector<ReductionNode *> variants;
181 
182   llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator;
183 };
184 
185 // Specialized iterator for SinglePath traversal
186 template <>
187 class ReductionNode::iterator<SinglePath>
188     : public BaseIterator<iterator<SinglePath>> {
189   friend BaseIterator<iterator<SinglePath>>;
190   using BaseIterator::BaseIterator;
191   ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node);
192 };
193 
194 } // namespace mlir
195 
196 #endif // MLIR_REDUCER_REDUCTIONNODE_H
197