1c484c7ddSChia-hung Duan //===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
2c484c7ddSChia-hung Duan //
3c484c7ddSChia-hung Duan // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c484c7ddSChia-hung Duan // See https://llvm.org/LICENSE.txt for license information.
5c484c7ddSChia-hung Duan // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c484c7ddSChia-hung Duan //
7c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
8c484c7ddSChia-hung Duan //
// This file defines the reduction nodes, which are used to keep track of the
// metadata for a specific generated variant within a reduction pass and are the
// building blocks of the reduction tree structure. A reduction tree is used to
// keep track of the different generated variants throughout a reduction pass in
// the MLIR Reduce tool.
14c484c7ddSChia-hung Duan //
15c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
16c484c7ddSChia-hung Duan
17c484c7ddSChia-hung Duan #include "mlir/Reducer/ReductionNode.h"
184d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
19c484c7ddSChia-hung Duan #include "llvm/ADT/STLExtras.h"
20c484c7ddSChia-hung Duan
21c484c7ddSChia-hung Duan #include <algorithm>
22c484c7ddSChia-hung Duan #include <limits>
23c484c7ddSChia-hung Duan
24c484c7ddSChia-hung Duan using namespace mlir;
25c484c7ddSChia-hung Duan
ReductionNode(ReductionNode * parentNode,const std::vector<Range> & ranges,llvm::SpecificBumpPtrAllocator<ReductionNode> & allocator)26c484c7ddSChia-hung Duan ReductionNode::ReductionNode(
271fc096afSMehdi Amini ReductionNode *parentNode, const std::vector<Range> &ranges,
28c484c7ddSChia-hung Duan llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
29c484c7ddSChia-hung Duan /// Root node will have the parent pointer point to themselves.
30c484c7ddSChia-hung Duan : parent(parentNode == nullptr ? this : parentNode),
31f829d62cSMehdi Amini size(std::numeric_limits<size_t>::max()), ranges(ranges),
32c484c7ddSChia-hung Duan startRanges(ranges), allocator(allocator) {
33c484c7ddSChia-hung Duan if (parent != this)
34c484c7ddSChia-hung Duan if (failed(initialize(parent->getModule(), parent->getRegion())))
35c484c7ddSChia-hung Duan llvm_unreachable("unexpected initialization failure");
36c484c7ddSChia-hung Duan }
37c484c7ddSChia-hung Duan
initialize(ModuleOp parentModule,Region & targetRegion)38c484c7ddSChia-hung Duan LogicalResult ReductionNode::initialize(ModuleOp parentModule,
39c484c7ddSChia-hung Duan Region &targetRegion) {
40c484c7ddSChia-hung Duan // Use the mapper help us find the corresponding region after module clone.
414d67b278SJeff Niu IRMapping mapper;
42c484c7ddSChia-hung Duan module = cast<ModuleOp>(parentModule->clone(mapper));
43c484c7ddSChia-hung Duan // Use the first block of targetRegion to locate the cloned region.
44c484c7ddSChia-hung Duan Block *block = mapper.lookup(&*targetRegion.begin());
45c484c7ddSChia-hung Duan region = block->getParent();
46c484c7ddSChia-hung Duan return success();
47c484c7ddSChia-hung Duan }
48c484c7ddSChia-hung Duan
49c484c7ddSChia-hung Duan /// If we haven't explored any variants from this node, we will create N
50c484c7ddSChia-hung Duan /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
51c484c7ddSChia-hung Duan /// max element in `ranges` and create 2 new variants for each call.
generateNewVariants()52c484c7ddSChia-hung Duan ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() {
53c484c7ddSChia-hung Duan int oldNumVariant = getVariants().size();
54c484c7ddSChia-hung Duan
554f415216SMehdi Amini auto createNewNode = [this](const std::vector<Range> &ranges) {
56337c937dSMehdi Amini return new (allocator.Allocate()) ReductionNode(this, ranges, allocator);
57c484c7ddSChia-hung Duan };
58c484c7ddSChia-hung Duan
59c484c7ddSChia-hung Duan // If we haven't created new variant, then we can create varients by removing
60c484c7ddSChia-hung Duan // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
61c484c7ddSChia-hung Duan // produce variants with range {{1, 3}} and {{4, 9}}.
625a1f6077SMehdi Amini if (variants.empty() && getRanges().size() > 1) {
63c484c7ddSChia-hung Duan for (const Range &range : getRanges()) {
64c484c7ddSChia-hung Duan std::vector<Range> subRanges = getRanges();
65f9306f6dSKazu Hirata llvm::erase(subRanges, range);
6660d13b85SMehdi Amini variants.push_back(createNewNode(subRanges));
67c484c7ddSChia-hung Duan }
68c484c7ddSChia-hung Duan
69c484c7ddSChia-hung Duan return getVariants().drop_front(oldNumVariant);
70c484c7ddSChia-hung Duan }
71c484c7ddSChia-hung Duan
72c484c7ddSChia-hung Duan // At here, we have created the type of variants mentioned above. We would
73c484c7ddSChia-hung Duan // like to split the max range into 2 to create 2 new variants. Continue on
74c484c7ddSChia-hung Duan // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
75c484c7ddSChia-hung Duan // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
76c484c7ddSChia-hung Duan // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
77*fab2bb8bSJustin Lebar auto maxElement =
78*fab2bb8bSJustin Lebar llvm::max_element(ranges, [](const Range &lhs, const Range &rhs) {
79c484c7ddSChia-hung Duan return (lhs.second - lhs.first) > (rhs.second - rhs.first);
80c484c7ddSChia-hung Duan });
81c484c7ddSChia-hung Duan
82c484c7ddSChia-hung Duan // The length of range is less than 1, we can't split it to create new
83c484c7ddSChia-hung Duan // variant.
84c484c7ddSChia-hung Duan if (maxElement->second - maxElement->first <= 1)
85c484c7ddSChia-hung Duan return {};
86c484c7ddSChia-hung Duan
87c484c7ddSChia-hung Duan Range maxRange = *maxElement;
88c484c7ddSChia-hung Duan std::vector<Range> subRanges = getRanges();
89c484c7ddSChia-hung Duan auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
90c484c7ddSChia-hung Duan int half = (maxRange.first + maxRange.second) / 2;
91c484c7ddSChia-hung Duan *subRangesIter = std::make_pair(maxRange.first, half);
92c484c7ddSChia-hung Duan variants.push_back(createNewNode(subRanges));
93c484c7ddSChia-hung Duan *subRangesIter = std::make_pair(half, maxRange.second);
9460d13b85SMehdi Amini variants.push_back(createNewNode(subRanges));
95c484c7ddSChia-hung Duan
96c484c7ddSChia-hung Duan auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
97c484c7ddSChia-hung Duan it = ranges.insert(it, std::make_pair(maxRange.first, half));
98c484c7ddSChia-hung Duan // Remove the range that has been split.
99c484c7ddSChia-hung Duan ranges.erase(it + 2);
100c484c7ddSChia-hung Duan
101c484c7ddSChia-hung Duan return getVariants().drop_front(oldNumVariant);
102c484c7ddSChia-hung Duan }
103c484c7ddSChia-hung Duan
update(std::pair<Tester::Interestingness,size_t> result)104c484c7ddSChia-hung Duan void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
105c484c7ddSChia-hung Duan std::tie(interesting, size) = result;
106c484c7ddSChia-hung Duan // After applying reduction, the number of operation in the region may have
107c484c7ddSChia-hung Duan // changed. Non-interesting case won't be explored thus it's safe to keep it
108c484c7ddSChia-hung Duan // in a stale status.
109c484c7ddSChia-hung Duan if (interesting == Tester::Interestingness::True) {
110c484c7ddSChia-hung Duan // This module may has been updated. Reset the range.
111c484c7ddSChia-hung Duan ranges.clear();
112e5639b3fSMehdi Amini ranges.emplace_back(0, std::distance(region->op_begin(), region->op_end()));
113ba913b8dSChia-hung Duan } else {
114ba913b8dSChia-hung Duan // Release the uninteresting module to save some memory.
115ba913b8dSChia-hung Duan module.release()->erase();
116c484c7ddSChia-hung Duan }
117c484c7ddSChia-hung Duan }
118c484c7ddSChia-hung Duan
119c484c7ddSChia-hung Duan ArrayRef<ReductionNode *>
getNeighbors(ReductionNode * node)120c484c7ddSChia-hung Duan ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
121c484c7ddSChia-hung Duan // Single Path: Traverses the smallest successful variant at each level until
122c484c7ddSChia-hung Duan // no new successful variants can be created at that level.
123c484c7ddSChia-hung Duan ArrayRef<ReductionNode *> variantsFromParent =
124c484c7ddSChia-hung Duan node->getParent()->getVariants();
125c484c7ddSChia-hung Duan
126c484c7ddSChia-hung Duan // The parent node created several variants and they may be waiting for
127c484c7ddSChia-hung Duan // examing interestingness. In Single Path approach, we will select the
128c484c7ddSChia-hung Duan // smallest variant to continue our exploration. Thus we should wait until the
129c484c7ddSChia-hung Duan // last variant to be examed then do the following traversal decision.
130c484c7ddSChia-hung Duan if (!llvm::all_of(variantsFromParent, [](ReductionNode *node) {
131c484c7ddSChia-hung Duan return node->isInteresting() != Tester::Interestingness::Untested;
132c484c7ddSChia-hung Duan })) {
133c484c7ddSChia-hung Duan return {};
134c484c7ddSChia-hung Duan }
135c484c7ddSChia-hung Duan
136c484c7ddSChia-hung Duan ReductionNode *smallest = nullptr;
137c484c7ddSChia-hung Duan for (ReductionNode *node : variantsFromParent) {
138c484c7ddSChia-hung Duan if (node->isInteresting() != Tester::Interestingness::True)
139c484c7ddSChia-hung Duan continue;
140c484c7ddSChia-hung Duan if (smallest == nullptr || node->getSize() < smallest->getSize())
141c484c7ddSChia-hung Duan smallest = node;
142c484c7ddSChia-hung Duan }
143c484c7ddSChia-hung Duan
144c484c7ddSChia-hung Duan if (smallest != nullptr &&
145c484c7ddSChia-hung Duan smallest->getSize() < node->getParent()->getSize()) {
146c484c7ddSChia-hung Duan // We got a smallest one, keep traversing from this node.
147c484c7ddSChia-hung Duan node = smallest;
148c484c7ddSChia-hung Duan } else {
149c484c7ddSChia-hung Duan // None of these variants is interesting, let the parent node to generate
150c484c7ddSChia-hung Duan // more variants.
151c484c7ddSChia-hung Duan node = node->getParent();
152c484c7ddSChia-hung Duan }
153c484c7ddSChia-hung Duan
154c484c7ddSChia-hung Duan return node->generateNewVariants();
155c484c7ddSChia-hung Duan }
156