xref: /llvm-project/mlir/lib/Reducer/ReductionNode.cpp (revision fab2bb8bfda865bd438dee981d7be7df8017b76d)
1c484c7ddSChia-hung Duan //===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
2c484c7ddSChia-hung Duan //
3c484c7ddSChia-hung Duan // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c484c7ddSChia-hung Duan // See https://llvm.org/LICENSE.txt for license information.
5c484c7ddSChia-hung Duan // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c484c7ddSChia-hung Duan //
7c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
8c484c7ddSChia-hung Duan //
// This file defines the reduction nodes, which keep track of the metadata for
// a specific generated variant within a reduction pass and are the building
// blocks of the reduction tree structure. A reduction tree is used to keep
// track of the different generated variants throughout a reduction pass in
// the MLIR Reduce tool.
14c484c7ddSChia-hung Duan //
15c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
16c484c7ddSChia-hung Duan 
17c484c7ddSChia-hung Duan #include "mlir/Reducer/ReductionNode.h"
184d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
19c484c7ddSChia-hung Duan #include "llvm/ADT/STLExtras.h"
20c484c7ddSChia-hung Duan 
21c484c7ddSChia-hung Duan #include <algorithm>
22c484c7ddSChia-hung Duan #include <limits>
23c484c7ddSChia-hung Duan 
24c484c7ddSChia-hung Duan using namespace mlir;
25c484c7ddSChia-hung Duan 
ReductionNode(ReductionNode * parentNode,const std::vector<Range> & ranges,llvm::SpecificBumpPtrAllocator<ReductionNode> & allocator)26c484c7ddSChia-hung Duan ReductionNode::ReductionNode(
271fc096afSMehdi Amini     ReductionNode *parentNode, const std::vector<Range> &ranges,
28c484c7ddSChia-hung Duan     llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
29c484c7ddSChia-hung Duan     /// Root node will have the parent pointer point to themselves.
30c484c7ddSChia-hung Duan     : parent(parentNode == nullptr ? this : parentNode),
31f829d62cSMehdi Amini       size(std::numeric_limits<size_t>::max()), ranges(ranges),
32c484c7ddSChia-hung Duan       startRanges(ranges), allocator(allocator) {
33c484c7ddSChia-hung Duan   if (parent != this)
34c484c7ddSChia-hung Duan     if (failed(initialize(parent->getModule(), parent->getRegion())))
35c484c7ddSChia-hung Duan       llvm_unreachable("unexpected initialization failure");
36c484c7ddSChia-hung Duan }
37c484c7ddSChia-hung Duan 
initialize(ModuleOp parentModule,Region & targetRegion)38c484c7ddSChia-hung Duan LogicalResult ReductionNode::initialize(ModuleOp parentModule,
39c484c7ddSChia-hung Duan                                         Region &targetRegion) {
40c484c7ddSChia-hung Duan   // Use the mapper help us find the corresponding region after module clone.
414d67b278SJeff Niu   IRMapping mapper;
42c484c7ddSChia-hung Duan   module = cast<ModuleOp>(parentModule->clone(mapper));
43c484c7ddSChia-hung Duan   // Use the first block of targetRegion to locate the cloned region.
44c484c7ddSChia-hung Duan   Block *block = mapper.lookup(&*targetRegion.begin());
45c484c7ddSChia-hung Duan   region = block->getParent();
46c484c7ddSChia-hung Duan   return success();
47c484c7ddSChia-hung Duan }
48c484c7ddSChia-hung Duan 
49c484c7ddSChia-hung Duan /// If we haven't explored any variants from this node, we will create N
50c484c7ddSChia-hung Duan /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
51c484c7ddSChia-hung Duan /// max element in `ranges` and create 2 new variants for each call.
generateNewVariants()52c484c7ddSChia-hung Duan ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() {
53c484c7ddSChia-hung Duan   int oldNumVariant = getVariants().size();
54c484c7ddSChia-hung Duan 
554f415216SMehdi Amini   auto createNewNode = [this](const std::vector<Range> &ranges) {
56337c937dSMehdi Amini     return new (allocator.Allocate()) ReductionNode(this, ranges, allocator);
57c484c7ddSChia-hung Duan   };
58c484c7ddSChia-hung Duan 
59c484c7ddSChia-hung Duan   // If we haven't created new variant, then we can create varients by removing
60c484c7ddSChia-hung Duan   // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
61c484c7ddSChia-hung Duan   // produce variants with range {{1, 3}} and {{4, 9}}.
625a1f6077SMehdi Amini   if (variants.empty() && getRanges().size() > 1) {
63c484c7ddSChia-hung Duan     for (const Range &range : getRanges()) {
64c484c7ddSChia-hung Duan       std::vector<Range> subRanges = getRanges();
65f9306f6dSKazu Hirata       llvm::erase(subRanges, range);
6660d13b85SMehdi Amini       variants.push_back(createNewNode(subRanges));
67c484c7ddSChia-hung Duan     }
68c484c7ddSChia-hung Duan 
69c484c7ddSChia-hung Duan     return getVariants().drop_front(oldNumVariant);
70c484c7ddSChia-hung Duan   }
71c484c7ddSChia-hung Duan 
72c484c7ddSChia-hung Duan   // At here, we have created the type of variants mentioned above. We would
73c484c7ddSChia-hung Duan   // like to split the max range into 2 to create 2 new variants. Continue on
74c484c7ddSChia-hung Duan   // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
75c484c7ddSChia-hung Duan   // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
76c484c7ddSChia-hung Duan   // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
77*fab2bb8bSJustin Lebar   auto maxElement =
78*fab2bb8bSJustin Lebar       llvm::max_element(ranges, [](const Range &lhs, const Range &rhs) {
79c484c7ddSChia-hung Duan         return (lhs.second - lhs.first) > (rhs.second - rhs.first);
80c484c7ddSChia-hung Duan       });
81c484c7ddSChia-hung Duan 
82c484c7ddSChia-hung Duan   // The length of range is less than 1, we can't split it to create new
83c484c7ddSChia-hung Duan   // variant.
84c484c7ddSChia-hung Duan   if (maxElement->second - maxElement->first <= 1)
85c484c7ddSChia-hung Duan     return {};
86c484c7ddSChia-hung Duan 
87c484c7ddSChia-hung Duan   Range maxRange = *maxElement;
88c484c7ddSChia-hung Duan   std::vector<Range> subRanges = getRanges();
89c484c7ddSChia-hung Duan   auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
90c484c7ddSChia-hung Duan   int half = (maxRange.first + maxRange.second) / 2;
91c484c7ddSChia-hung Duan   *subRangesIter = std::make_pair(maxRange.first, half);
92c484c7ddSChia-hung Duan   variants.push_back(createNewNode(subRanges));
93c484c7ddSChia-hung Duan   *subRangesIter = std::make_pair(half, maxRange.second);
9460d13b85SMehdi Amini   variants.push_back(createNewNode(subRanges));
95c484c7ddSChia-hung Duan 
96c484c7ddSChia-hung Duan   auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
97c484c7ddSChia-hung Duan   it = ranges.insert(it, std::make_pair(maxRange.first, half));
98c484c7ddSChia-hung Duan   // Remove the range that has been split.
99c484c7ddSChia-hung Duan   ranges.erase(it + 2);
100c484c7ddSChia-hung Duan 
101c484c7ddSChia-hung Duan   return getVariants().drop_front(oldNumVariant);
102c484c7ddSChia-hung Duan }
103c484c7ddSChia-hung Duan 
update(std::pair<Tester::Interestingness,size_t> result)104c484c7ddSChia-hung Duan void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
105c484c7ddSChia-hung Duan   std::tie(interesting, size) = result;
106c484c7ddSChia-hung Duan   // After applying reduction, the number of operation in the region may have
107c484c7ddSChia-hung Duan   // changed. Non-interesting case won't be explored thus it's safe to keep it
108c484c7ddSChia-hung Duan   // in a stale status.
109c484c7ddSChia-hung Duan   if (interesting == Tester::Interestingness::True) {
110c484c7ddSChia-hung Duan     // This module may has been updated. Reset the range.
111c484c7ddSChia-hung Duan     ranges.clear();
112e5639b3fSMehdi Amini     ranges.emplace_back(0, std::distance(region->op_begin(), region->op_end()));
113ba913b8dSChia-hung Duan   } else {
114ba913b8dSChia-hung Duan     // Release the uninteresting module to save some memory.
115ba913b8dSChia-hung Duan     module.release()->erase();
116c484c7ddSChia-hung Duan   }
117c484c7ddSChia-hung Duan }
118c484c7ddSChia-hung Duan 
119c484c7ddSChia-hung Duan ArrayRef<ReductionNode *>
getNeighbors(ReductionNode * node)120c484c7ddSChia-hung Duan ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
121c484c7ddSChia-hung Duan   // Single Path: Traverses the smallest successful variant at each level until
122c484c7ddSChia-hung Duan   // no new successful variants can be created at that level.
123c484c7ddSChia-hung Duan   ArrayRef<ReductionNode *> variantsFromParent =
124c484c7ddSChia-hung Duan       node->getParent()->getVariants();
125c484c7ddSChia-hung Duan 
126c484c7ddSChia-hung Duan   // The parent node created several variants and they may be waiting for
127c484c7ddSChia-hung Duan   // examing interestingness. In Single Path approach, we will select the
128c484c7ddSChia-hung Duan   // smallest variant to continue our exploration. Thus we should wait until the
129c484c7ddSChia-hung Duan   // last variant to be examed then do the following traversal decision.
130c484c7ddSChia-hung Duan   if (!llvm::all_of(variantsFromParent, [](ReductionNode *node) {
131c484c7ddSChia-hung Duan         return node->isInteresting() != Tester::Interestingness::Untested;
132c484c7ddSChia-hung Duan       })) {
133c484c7ddSChia-hung Duan     return {};
134c484c7ddSChia-hung Duan   }
135c484c7ddSChia-hung Duan 
136c484c7ddSChia-hung Duan   ReductionNode *smallest = nullptr;
137c484c7ddSChia-hung Duan   for (ReductionNode *node : variantsFromParent) {
138c484c7ddSChia-hung Duan     if (node->isInteresting() != Tester::Interestingness::True)
139c484c7ddSChia-hung Duan       continue;
140c484c7ddSChia-hung Duan     if (smallest == nullptr || node->getSize() < smallest->getSize())
141c484c7ddSChia-hung Duan       smallest = node;
142c484c7ddSChia-hung Duan   }
143c484c7ddSChia-hung Duan 
144c484c7ddSChia-hung Duan   if (smallest != nullptr &&
145c484c7ddSChia-hung Duan       smallest->getSize() < node->getParent()->getSize()) {
146c484c7ddSChia-hung Duan     // We got a smallest one, keep traversing from this node.
147c484c7ddSChia-hung Duan     node = smallest;
148c484c7ddSChia-hung Duan   } else {
149c484c7ddSChia-hung Duan     // None of these variants is interesting, let the parent node to generate
150c484c7ddSChia-hung Duan     // more variants.
151c484c7ddSChia-hung Duan     node = node->getParent();
152c484c7ddSChia-hung Duan   }
153c484c7ddSChia-hung Duan 
154c484c7ddSChia-hung Duan   return node->generateNewVariants();
155c484c7ddSChia-hung Duan }
156