xref: /llvm-project/mlir/lib/Reducer/ReductionNode.cpp (revision f829d62c219c6b880f0106a4a75c0f8640fcfe54)
1 //===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the reduction nodes which are used to keep track of the
10 // metadata for a specific generated variant within a reduction pass and are the
11 // building blocks of the reduction tree structure. A reduction tree is used to
12 // keep track of the different generated variants throughout a reduction pass in
13 // the MLIR Reduce tool.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Reducer/ReductionNode.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 #include <algorithm>
22 #include <limits>
23 
24 using namespace mlir;
25 
26 ReductionNode::ReductionNode(
27     ReductionNode *parentNode, const std::vector<Range> &ranges,
28     llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
29     /// Root node will have the parent pointer point to themselves.
30     : parent(parentNode == nullptr ? this : parentNode),
31       size(std::numeric_limits<size_t>::max()), ranges(ranges),
32       startRanges(ranges), allocator(allocator) {
33   if (parent != this)
34     if (failed(initialize(parent->getModule(), parent->getRegion())))
35       llvm_unreachable("unexpected initialization failure");
36 }
37 
38 LogicalResult ReductionNode::initialize(ModuleOp parentModule,
39                                         Region &targetRegion) {
40   // Use the mapper help us find the corresponding region after module clone.
41   BlockAndValueMapping mapper;
42   module = cast<ModuleOp>(parentModule->clone(mapper));
43   // Use the first block of targetRegion to locate the cloned region.
44   Block *block = mapper.lookup(&*targetRegion.begin());
45   region = block->getParent();
46   return success();
47 }
48 
49 /// If we haven't explored any variants from this node, we will create N
50 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
51 /// max element in `ranges` and create 2 new variants for each call.
52 ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() {
53   int oldNumVariant = getVariants().size();
54 
55   auto createNewNode = [this](std::vector<Range> ranges) {
56     return new (allocator.Allocate())
57         ReductionNode(this, std::move(ranges), allocator);
58   };
59 
60   // If we haven't created new variant, then we can create varients by removing
61   // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
62   // produce variants with range {{1, 3}} and {{4, 9}}.
63   if (variants.empty() && getRanges().size() > 1) {
64     for (const Range &range : getRanges()) {
65       std::vector<Range> subRanges = getRanges();
66       llvm::erase_value(subRanges, range);
67       variants.push_back(createNewNode(std::move(subRanges)));
68     }
69 
70     return getVariants().drop_front(oldNumVariant);
71   }
72 
73   // At here, we have created the type of variants mentioned above. We would
74   // like to split the max range into 2 to create 2 new variants. Continue on
75   // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
76   // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
77   // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
78   auto maxElement = std::max_element(
79       ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) {
80         return (lhs.second - lhs.first) > (rhs.second - rhs.first);
81       });
82 
83   // The length of range is less than 1, we can't split it to create new
84   // variant.
85   if (maxElement->second - maxElement->first <= 1)
86     return {};
87 
88   Range maxRange = *maxElement;
89   std::vector<Range> subRanges = getRanges();
90   auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
91   int half = (maxRange.first + maxRange.second) / 2;
92   *subRangesIter = std::make_pair(maxRange.first, half);
93   variants.push_back(createNewNode(subRanges));
94   *subRangesIter = std::make_pair(half, maxRange.second);
95   variants.push_back(createNewNode(std::move(subRanges)));
96 
97   auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
98   it = ranges.insert(it, std::make_pair(maxRange.first, half));
99   // Remove the range that has been split.
100   ranges.erase(it + 2);
101 
102   return getVariants().drop_front(oldNumVariant);
103 }
104 
105 void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
106   std::tie(interesting, size) = result;
107   // After applying reduction, the number of operation in the region may have
108   // changed. Non-interesting case won't be explored thus it's safe to keep it
109   // in a stale status.
110   if (interesting == Tester::Interestingness::True) {
111     // This module may has been updated. Reset the range.
112     ranges.clear();
113     ranges.emplace_back(0, std::distance(region->op_begin(), region->op_end()));
114   } else {
115     // Release the uninteresting module to save some memory.
116     module.release()->erase();
117   }
118 }
119 
120 ArrayRef<ReductionNode *>
121 ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
122   // Single Path: Traverses the smallest successful variant at each level until
123   // no new successful variants can be created at that level.
124   ArrayRef<ReductionNode *> variantsFromParent =
125       node->getParent()->getVariants();
126 
127   // The parent node created several variants and they may be waiting for
128   // examing interestingness. In Single Path approach, we will select the
129   // smallest variant to continue our exploration. Thus we should wait until the
130   // last variant to be examed then do the following traversal decision.
131   if (!llvm::all_of(variantsFromParent, [](ReductionNode *node) {
132         return node->isInteresting() != Tester::Interestingness::Untested;
133       })) {
134     return {};
135   }
136 
137   ReductionNode *smallest = nullptr;
138   for (ReductionNode *node : variantsFromParent) {
139     if (node->isInteresting() != Tester::Interestingness::True)
140       continue;
141     if (smallest == nullptr || node->getSize() < smallest->getSize())
142       smallest = node;
143   }
144 
145   if (smallest != nullptr &&
146       smallest->getSize() < node->getParent()->getSize()) {
147     // We got a smallest one, keep traversing from this node.
148     node = smallest;
149   } else {
150     // None of these variants is interesting, let the parent node to generate
151     // more variants.
152     node = node->getParent();
153   }
154 
155   return node->generateNewVariants();
156 }
157