1 //===- ReductionNode.cpp - Reduction Node Implementation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the reduction nodes which are used to track of the 10 // metadata for a specific generated variant within a reduction pass and are the 11 // building blocks of the reduction tree structure. A reduction tree is used to 12 // keep track of the different generated variants throughout a reduction pass in 13 // the MLIR Reduce tool. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Reducer/ReductionNode.h" 18 #include "mlir/IR/BlockAndValueMapping.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 #include <algorithm> 22 #include <limits> 23 24 using namespace mlir; 25 26 ReductionNode::ReductionNode( 27 ReductionNode *parentNode, const std::vector<Range> &ranges, 28 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator) 29 /// Root node will have the parent pointer point to themselves. 30 : parent(parentNode == nullptr ? this : parentNode), 31 size(std::numeric_limits<size_t>::max()), ranges(ranges), 32 startRanges(ranges), allocator(allocator) { 33 if (parent != this) 34 if (failed(initialize(parent->getModule(), parent->getRegion()))) 35 llvm_unreachable("unexpected initialization failure"); 36 } 37 38 LogicalResult ReductionNode::initialize(ModuleOp parentModule, 39 Region &targetRegion) { 40 // Use the mapper help us find the corresponding region after module clone. 41 BlockAndValueMapping mapper; 42 module = cast<ModuleOp>(parentModule->clone(mapper)); 43 // Use the first block of targetRegion to locate the cloned region. 44 Block *block = mapper.lookup(&*targetRegion.begin()); 45 region = block->getParent(); 46 return success(); 47 } 48 49 /// If we haven't explored any variants from this node, we will create N 50 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the 51 /// max element in `ranges` and create 2 new variants for each call. 52 ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() { 53 int oldNumVariant = getVariants().size(); 54 55 auto createNewNode = [this](std::vector<Range> ranges) { 56 return new (allocator.Allocate()) 57 ReductionNode(this, std::move(ranges), allocator); 58 }; 59 60 // If we haven't created new variant, then we can create varients by removing 61 // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can 62 // produce variants with range {{1, 3}} and {{4, 9}}. 63 if (variants.empty() && getRanges().size() > 1) { 64 for (const Range &range : getRanges()) { 65 std::vector<Range> subRanges = getRanges(); 66 llvm::erase_value(subRanges, range); 67 variants.push_back(createNewNode(std::move(subRanges))); 68 } 69 70 return getVariants().drop_front(oldNumVariant); 71 } 72 73 // At here, we have created the type of variants mentioned above. We would 74 // like to split the max range into 2 to create 2 new variants. Continue on 75 // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and 76 // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The 77 // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}. 78 auto maxElement = std::max_element( 79 ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) { 80 return (lhs.second - lhs.first) > (rhs.second - rhs.first); 81 }); 82 83 // The length of range is less than 1, we can't split it to create new 84 // variant. 85 if (maxElement->second - maxElement->first <= 1) 86 return {}; 87 88 Range maxRange = *maxElement; 89 std::vector<Range> subRanges = getRanges(); 90 auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin()); 91 int half = (maxRange.first + maxRange.second) / 2; 92 *subRangesIter = std::make_pair(maxRange.first, half); 93 variants.push_back(createNewNode(subRanges)); 94 *subRangesIter = std::make_pair(half, maxRange.second); 95 variants.push_back(createNewNode(std::move(subRanges))); 96 97 auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second)); 98 it = ranges.insert(it, std::make_pair(maxRange.first, half)); 99 // Remove the range that has been split. 100 ranges.erase(it + 2); 101 102 return getVariants().drop_front(oldNumVariant); 103 } 104 105 void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) { 106 std::tie(interesting, size) = result; 107 // After applying reduction, the number of operation in the region may have 108 // changed. Non-interesting case won't be explored thus it's safe to keep it 109 // in a stale status. 110 if (interesting == Tester::Interestingness::True) { 111 // This module may has been updated. Reset the range. 112 ranges.clear(); 113 ranges.emplace_back(0, std::distance(region->op_begin(), region->op_end())); 114 } else { 115 // Release the uninteresting module to save some memory. 116 module.release()->erase(); 117 } 118 } 119 120 ArrayRef<ReductionNode *> 121 ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) { 122 // Single Path: Traverses the smallest successful variant at each level until 123 // no new successful variants can be created at that level. 124 ArrayRef<ReductionNode *> variantsFromParent = 125 node->getParent()->getVariants(); 126 127 // The parent node created several variants and they may be waiting for 128 // examing interestingness. In Single Path approach, we will select the 129 // smallest variant to continue our exploration. Thus we should wait until the 130 // last variant to be examed then do the following traversal decision. 131 if (!llvm::all_of(variantsFromParent, [](ReductionNode *node) { 132 return node->isInteresting() != Tester::Interestingness::Untested; 133 })) { 134 return {}; 135 } 136 137 ReductionNode *smallest = nullptr; 138 for (ReductionNode *node : variantsFromParent) { 139 if (node->isInteresting() != Tester::Interestingness::True) 140 continue; 141 if (smallest == nullptr || node->getSize() < smallest->getSize()) 142 smallest = node; 143 } 144 145 if (smallest != nullptr && 146 smallest->getSize() < node->getParent()->getSize()) { 147 // We got a smallest one, keep traversing from this node. 148 node = smallest; 149 } else { 150 // None of these variants is interesting, let the parent node to generate 151 // more variants. 152 node = node->getParent(); 153 } 154 155 return node->generateNewVariants(); 156 } 157