1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" 10 11 #include "mlir/Interfaces/ControlFlowInterfaces.h" 12 #include "mlir/Interfaces/ViewLikeInterface.h" 13 #include "llvm/ADT/SetOperations.h" 14 #include "llvm/ADT/SetVector.h" 15 16 using namespace mlir; 17 18 /// Constructs a new alias analysis using the op provided. 19 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); } 20 21 /// Find all immediate and indirect dependent buffers this value could 22 /// potentially have. Note that the resulting set will also contain the value 23 /// provided as it is a dependent alias of itself. 24 BufferViewFlowAnalysis::ValueSetT 25 BufferViewFlowAnalysis::resolve(Value rootValue) const { 26 ValueSetT result; 27 SmallVector<Value, 8> queue; 28 queue.push_back(rootValue); 29 while (!queue.empty()) { 30 Value currentValue = queue.pop_back_val(); 31 if (result.insert(currentValue).second) { 32 auto it = dependencies.find(currentValue); 33 if (it != dependencies.end()) { 34 for (Value aliasValue : it->second) 35 queue.push_back(aliasValue); 36 } 37 } 38 } 39 return result; 40 } 41 42 /// Removes the given values from all alias sets. 43 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) { 44 for (auto &entry : dependencies) 45 llvm::set_subtract(entry.second, aliasValues); 46 } 47 48 void BufferViewFlowAnalysis::rename(Value from, Value to) { 49 dependencies[to] = dependencies[from]; 50 dependencies.erase(from); 51 52 for (auto &[_, value] : dependencies) { 53 if (value.contains(from)) { 54 value.insert(to); 55 value.erase(from); 56 } 57 } 58 } 59 60 /// This function constructs a mapping from values to its immediate 61 /// dependencies. It iterates over all blocks, gets their predecessors, 62 /// determines the values that will be passed to the corresponding block 63 /// arguments and inserts them into the underlying map. Furthermore, it wires 64 /// successor regions and branch-like return operations from nested regions. 65 void BufferViewFlowAnalysis::build(Operation *op) { 66 // Registers all dependencies of the given values. 67 auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { 68 for (auto [value, dep] : llvm::zip(values, dependencies)) 69 this->dependencies[value].insert(dep); 70 }; 71 72 op->walk([&](Operation *op) { 73 // TODO: We should have an op interface instead of a hard-coded list of 74 // interfaces/ops. 75 76 // Add additional dependencies created by view changes to the alias list. 77 if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) { 78 dependencies[viewInterface.getViewSource()].insert( 79 viewInterface->getResult(0)); 80 return WalkResult::advance(); 81 } 82 83 if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) { 84 // Query all branch interfaces to link block argument dependencies. 85 Block *parentBlock = branchInterface->getBlock(); 86 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); 87 it != e; ++it) { 88 // Query the branch op interface to get the successor operands. 89 auto successorOperands = 90 branchInterface.getSuccessorOperands(it.getIndex()); 91 // Build the actual mapping of values to their immediate dependencies. 92 registerDependencies(successorOperands.getForwardedOperands(), 93 (*it)->getArguments().drop_front( 94 successorOperands.getProducedOperandCount())); 95 } 96 return WalkResult::advance(); 97 } 98 99 if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) { 100 // Query the RegionBranchOpInterface to find potential successor regions. 101 // Extract all entry regions and wire all initial entry successor inputs. 102 SmallVector<RegionSuccessor, 2> entrySuccessors; 103 regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(), 104 entrySuccessors); 105 for (RegionSuccessor &entrySuccessor : entrySuccessors) { 106 // Wire the entry region's successor arguments with the initial 107 // successor inputs. 108 registerDependencies( 109 regionInterface.getEntrySuccessorOperands(entrySuccessor), 110 entrySuccessor.getSuccessorInputs()); 111 } 112 113 // Wire flow between regions and from region exits. 114 for (Region ®ion : regionInterface->getRegions()) { 115 // Iterate over all successor region entries that are reachable from the 116 // current region. 117 SmallVector<RegionSuccessor, 2> successorRegions; 118 regionInterface.getSuccessorRegions(region, successorRegions); 119 for (RegionSuccessor &successorRegion : successorRegions) { 120 // Iterate over all immediate terminator operations and wire the 121 // successor inputs with the successor operands of each terminator. 122 for (Block &block : region) 123 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>( 124 block.getTerminator())) 125 registerDependencies( 126 terminator.getSuccessorOperands(successorRegion), 127 successorRegion.getSuccessorInputs()); 128 } 129 } 130 131 return WalkResult::advance(); 132 } 133 134 // Unknown op: Assume that all operands alias with all results. 135 for (Value operand : op->getOperands()) { 136 if (!isa<BaseMemRefType>(operand.getType())) 137 continue; 138 for (Value result : op->getResults()) { 139 if (!isa<BaseMemRefType>(result.getType())) 140 continue; 141 registerDependencies({operand}, {result}); 142 } 143 } 144 return WalkResult::advance(); 145 }); 146 } 147