1 //===- ReductionNode.h - Reduction Node Implementation ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the reduction nodes which are used to track of the metadata 10 // for a specific generated variant within a reduction pass and are the building 11 // blocks of the reduction tree structure. A reduction tree is used to keep 12 // track of the different generated variants throughout a reduction pass in the 13 // MLIR Reduce tool. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #ifndef MLIR_REDUCER_REDUCTIONNODE_H 18 #define MLIR_REDUCER_REDUCTIONNODE_H 19 20 #include <queue> 21 #include <vector> 22 23 #include "mlir/IR/OwningOpRef.h" 24 #include "mlir/Reducer/Tester.h" 25 #include "llvm/ADT/ArrayRef.h" 26 #include "llvm/Support/Allocator.h" 27 #include "llvm/Support/ToolOutputFile.h" 28 29 namespace mlir { 30 31 class ModuleOp; 32 class Region; 33 34 /// Defines the traversal method options to be used in the reduction tree 35 /// traversal. 36 enum TraversalMode { SinglePath, Backtrack, MultiPath }; 37 38 /// ReductionTreePass will build a reduction tree during module reduction and 39 /// the ReductionNode represents the vertex of the tree. A ReductionNode records 40 /// the information such as the reduced module, how this node is reduced from 41 /// the parent node, etc. This information will be used to construct a reduction 42 /// path to reduce the certain module. 43 class ReductionNode { 44 public: 45 template <TraversalMode mode> 46 class iterator; 47 48 using Range = std::pair<int, int>; 49 50 ReductionNode(ReductionNode *parent, const std::vector<Range> &range, 51 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator); 52 getParent()53 ReductionNode *getParent() const { return parent; } 54 55 /// If the ReductionNode hasn't been tested the interestingness, it'll be the 56 /// same module as the one in the parent node. Otherwise, the returned module 57 /// will have been applied certain reduction strategies. Note that it's not 58 /// necessary to be an interesting case or a reduced module (has smaller size 59 /// than parent's). getModule()60 ModuleOp getModule() const { return module.get(); } 61 62 /// Return the region we're reducing. getRegion()63 Region &getRegion() const { return *region; } 64 65 /// Return the size of the module. getSize()66 size_t getSize() const { return size; } 67 68 /// Returns true if the module exhibits the interesting behavior. isInteresting()69 Tester::Interestingness isInteresting() const { return interesting; } 70 71 /// Return the range information that how this node is reduced from the parent 72 /// node. getStartRanges()73 ArrayRef<Range> getStartRanges() const { return startRanges; } 74 75 /// Return the range set we are using to generate variants. getRanges()76 ArrayRef<Range> getRanges() const { return ranges; } 77 78 /// Return the generated variants(the child nodes). getVariants()79 ArrayRef<ReductionNode *> getVariants() const { return variants; } 80 81 /// Split the ranges and generate new variants. 82 ArrayRef<ReductionNode *> generateNewVariants(); 83 84 /// Update the interestingness result from tester. 85 void update(std::pair<Tester::Interestingness, size_t> result); 86 87 /// Each Reduction Node contains a copy of module for applying rewrite 88 /// patterns. In addition, we only apply rewrite patterns in a certain region. 89 /// In init(), we will duplicate the module from parent node and locate the 90 /// corresponding region. 91 LogicalResult initialize(ModuleOp parentModule, Region &parentRegion); 92 93 private: 94 /// A custom BFS iterator. The difference between 95 /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic. 96 /// We may explore more neighbors at certain node if we didn't find interested 97 /// event. As a result, we defer pushing adjacent nodes until poping the last 98 /// visited node. The graph exploration strategy will be put in 99 /// getNeighbors(). 100 /// 101 /// Subclass BaseIterator and implement traversal strategy in getNeighbors(). 102 template <typename T> 103 class BaseIterator { 104 public: BaseIterator(ReductionNode * node)105 BaseIterator(ReductionNode *node) { visitQueue.push(node); } 106 BaseIterator(const BaseIterator &) = default; 107 BaseIterator() = default; 108 end()109 static BaseIterator end() { return BaseIterator(); } 110 111 bool operator==(const BaseIterator &i) { 112 return visitQueue == i.visitQueue; 113 } 114 bool operator!=(const BaseIterator &i) { return !(*this == i); } 115 116 BaseIterator &operator++() { 117 ReductionNode *top = visitQueue.front(); 118 visitQueue.pop(); 119 for (ReductionNode *node : getNeighbors(top)) 120 visitQueue.push(node); 121 return *this; 122 } 123 124 BaseIterator operator++(int) { 125 BaseIterator tmp = *this; 126 ++*this; 127 return tmp; 128 } 129 130 ReductionNode &operator*() const { return *(visitQueue.front()); } 131 ReductionNode *operator->() const { return visitQueue.front(); } 132 133 protected: getNeighbors(ReductionNode * node)134 ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) { 135 return static_cast<T *>(this)->getNeighbors(node); 136 } 137 138 private: 139 std::queue<ReductionNode *> visitQueue; 140 }; 141 142 /// This is a copy of module from parent node. All the reducer patterns will 143 /// be applied to this instance. 144 OwningOpRef<ModuleOp> module; 145 146 /// The region of certain operation we're reducing in the module 147 Region *region = nullptr; 148 149 /// The node we are reduced from. It means we will be in variants of parent 150 /// node. 151 ReductionNode *parent = nullptr; 152 153 /// The size of module after applying the reducer patterns with range 154 /// constraints. This is only valid while the interestingness has been tested. 155 size_t size = 0; 156 157 /// This is true if the module has been evaluated and it exhibits the 158 /// interesting behavior. 159 Tester::Interestingness interesting = Tester::Interestingness::Untested; 160 161 /// `ranges` represents the selected subset of operations in the region. We 162 /// implicitly number each operation in the region and ReductionTreePass will 163 /// apply reducer patterns on the operation falls into the `ranges`. We will 164 /// generate new ReductionNode with subset of `ranges` to see if we can do 165 /// further reduction. we may split the element in the `ranges` so that we can 166 /// have more subset variants from `ranges`. 167 /// Note that after applying the reducer patterns the number of operation in 168 /// the region may have changed, we need to update the `ranges` after that. 169 std::vector<Range> ranges; 170 171 /// `startRanges` records the ranges of operations selected from the parent 172 /// node to produce this ReductionNode. It can be used to construct the 173 /// reduction path from the root. I.e., if we apply the same reducer patterns 174 /// and `startRanges` selection on the parent region, we will get the same 175 /// module as this node. 176 const std::vector<Range> startRanges; 177 178 /// This points to the child variants that were created using this node as a 179 /// starting point. 180 std::vector<ReductionNode *> variants; 181 182 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator; 183 }; 184 185 // Specialized iterator for SinglePath traversal 186 template <> 187 class ReductionNode::iterator<SinglePath> 188 : public BaseIterator<iterator<SinglePath>> { 189 friend BaseIterator<iterator<SinglePath>>; 190 using BaseIterator::BaseIterator; 191 ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node); 192 }; 193 194 } // namespace mlir 195 196 #endif // MLIR_REDUCER_REDUCTIONNODE_H 197