1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" 10 11 #include "mlir/Interfaces/ControlFlowInterfaces.h" 12 #include "mlir/Interfaces/ViewLikeInterface.h" 13 #include "llvm/ADT/SetOperations.h" 14 #include "llvm/ADT/SetVector.h" 15 16 using namespace mlir; 17 18 /// Constructs a new alias analysis using the op provided. 19 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); } 20 21 /// Find all immediate and indirect dependent buffers this value could 22 /// potentially have. Note that the resulting set will also contain the value 23 /// provided as it is a dependent alias of itself. 24 BufferViewFlowAnalysis::ValueSetT 25 BufferViewFlowAnalysis::resolve(Value rootValue) const { 26 ValueSetT result; 27 SmallVector<Value, 8> queue; 28 queue.push_back(rootValue); 29 while (!queue.empty()) { 30 Value currentValue = queue.pop_back_val(); 31 if (result.insert(currentValue).second) { 32 auto it = dependencies.find(currentValue); 33 if (it != dependencies.end()) { 34 for (Value aliasValue : it->second) 35 queue.push_back(aliasValue); 36 } 37 } 38 } 39 return result; 40 } 41 42 /// Removes the given values from all alias sets. 43 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) { 44 for (auto &entry : dependencies) 45 llvm::set_subtract(entry.second, aliasValues); 46 } 47 48 /// This function constructs a mapping from values to its immediate 49 /// dependencies. It iterates over all blocks, gets their predecessors, 50 /// determines the values that will be passed to the corresponding block 51 /// arguments and inserts them into the underlying map. Furthermore, it wires 52 /// successor regions and branch-like return operations from nested regions. 53 void BufferViewFlowAnalysis::build(Operation *op) { 54 // Registers all dependencies of the given values. 55 auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { 56 for (auto [value, dep] : llvm::zip(values, dependencies)) 57 this->dependencies[value].insert(dep); 58 }; 59 60 op->walk([&](Operation *op) { 61 // TODO: We should have an op interface instead of a hard-coded list of 62 // interfaces/ops. 63 64 // Add additional dependencies created by view changes to the alias list. 65 if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) { 66 dependencies[viewInterface.getViewSource()].insert( 67 viewInterface->getResult(0)); 68 return WalkResult::advance(); 69 } 70 71 if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) { 72 // Query all branch interfaces to link block argument dependencies. 73 Block *parentBlock = branchInterface->getBlock(); 74 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); 75 it != e; ++it) { 76 // Query the branch op interface to get the successor operands. 77 auto successorOperands = 78 branchInterface.getSuccessorOperands(it.getIndex()); 79 // Build the actual mapping of values to their immediate dependencies. 80 registerDependencies(successorOperands.getForwardedOperands(), 81 (*it)->getArguments().drop_front( 82 successorOperands.getProducedOperandCount())); 83 } 84 return WalkResult::advance(); 85 } 86 87 if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) { 88 // Query the RegionBranchOpInterface to find potential successor regions. 89 // Extract all entry regions and wire all initial entry successor inputs. 90 SmallVector<RegionSuccessor, 2> entrySuccessors; 91 regionInterface.getSuccessorRegions(/*index=*/std::nullopt, 92 entrySuccessors); 93 for (RegionSuccessor &entrySuccessor : entrySuccessors) { 94 // Wire the entry region's successor arguments with the initial 95 // successor inputs. 96 assert(entrySuccessor.getSuccessor() && 97 "Invalid entry region without an attached successor region"); 98 registerDependencies( 99 regionInterface.getSuccessorEntryOperands( 100 entrySuccessor.getSuccessor()->getRegionNumber()), 101 entrySuccessor.getSuccessorInputs()); 102 } 103 104 // Wire flow between regions and from region exits. 105 for (Region ®ion : regionInterface->getRegions()) { 106 // Iterate over all successor region entries that are reachable from the 107 // current region. 108 SmallVector<RegionSuccessor, 2> successorRegions; 109 regionInterface.getSuccessorRegions(region.getRegionNumber(), 110 successorRegions); 111 for (RegionSuccessor &successorRegion : successorRegions) { 112 // Determine the current region index (if any). 113 std::optional<unsigned> regionIndex; 114 Region *regionSuccessor = successorRegion.getSuccessor(); 115 if (regionSuccessor) 116 regionIndex = regionSuccessor->getRegionNumber(); 117 // Iterate over all immediate terminator operations and wire the 118 // successor inputs with the successor operands of each terminator. 119 for (Block &block : region) { 120 auto successorOperands = getRegionBranchSuccessorOperands( 121 block.getTerminator(), regionIndex); 122 if (successorOperands) { 123 registerDependencies(*successorOperands, 124 successorRegion.getSuccessorInputs()); 125 } 126 } 127 } 128 } 129 130 return WalkResult::advance(); 131 } 132 133 // Unknown op: Assume that all operands alias with all results. 134 for (Value operand : op->getOperands()) { 135 if (!isa<BaseMemRefType>(operand.getType())) 136 continue; 137 for (Value result : op->getResults()) { 138 if (!isa<BaseMemRefType>(result.getType())) 139 continue; 140 registerDependencies({operand}, {result}); 141 } 142 } 143 return WalkResult::advance(); 144 }); 145 } 146