1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" 12 #include "mlir/Interfaces/CallInterfaces.h" 13 #include "mlir/Interfaces/ControlFlowInterfaces.h" 14 #include "mlir/Interfaces/FunctionInterfaces.h" 15 #include "mlir/Interfaces/ViewLikeInterface.h" 16 #include "llvm/ADT/SetOperations.h" 17 #include "llvm/ADT/SetVector.h" 18 19 using namespace mlir; 20 using namespace mlir::bufferization; 21 22 /// Constructs a new alias analysis using the op provided. 23 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); } 24 25 /// Find all immediate and indirect dependent buffers this value could 26 /// potentially have. Note that the resulting set will also contain the value 27 /// provided as it is a dependent alias of itself. 28 BufferViewFlowAnalysis::ValueSetT 29 BufferViewFlowAnalysis::resolve(Value rootValue) const { 30 ValueSetT result; 31 SmallVector<Value, 8> queue; 32 queue.push_back(rootValue); 33 while (!queue.empty()) { 34 Value currentValue = queue.pop_back_val(); 35 if (result.insert(currentValue).second) { 36 auto it = dependencies.find(currentValue); 37 if (it != dependencies.end()) { 38 for (Value aliasValue : it->second) 39 queue.push_back(aliasValue); 40 } 41 } 42 } 43 return result; 44 } 45 46 /// Removes the given values from all alias sets. 47 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) { 48 for (auto &entry : dependencies) 49 llvm::set_subtract(entry.second, aliasValues); 50 } 51 52 void BufferViewFlowAnalysis::rename(Value from, Value to) { 53 dependencies[to] = dependencies[from]; 54 dependencies.erase(from); 55 56 for (auto &[_, value] : dependencies) { 57 if (value.contains(from)) { 58 value.insert(to); 59 value.erase(from); 60 } 61 } 62 } 63 64 /// This function constructs a mapping from values to its immediate 65 /// dependencies. It iterates over all blocks, gets their predecessors, 66 /// determines the values that will be passed to the corresponding block 67 /// arguments and inserts them into the underlying map. Furthermore, it wires 68 /// successor regions and branch-like return operations from nested regions. 69 void BufferViewFlowAnalysis::build(Operation *op) { 70 // Registers all dependencies of the given values. 71 auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { 72 for (auto [value, dep] : llvm::zip_equal(values, dependencies)) 73 this->dependencies[value].insert(dep); 74 }; 75 76 // Mark all buffer results and buffer region entry block arguments of the 77 // given op as terminals. 78 auto populateTerminalValues = [&](Operation *op) { 79 for (Value v : op->getResults()) 80 if (isa<BaseMemRefType>(v.getType())) 81 this->terminals.insert(v); 82 for (Region &r : op->getRegions()) 83 for (BlockArgument v : r.getArguments()) 84 if (isa<BaseMemRefType>(v.getType())) 85 this->terminals.insert(v); 86 }; 87 88 op->walk([&](Operation *op) { 89 // Query BufferViewFlowOpInterface. If the op does not implement that 90 // interface, try to infer the dependencies from other interfaces that the 91 // op may implement. 92 if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) { 93 bufferViewFlowOp.populateDependencies(registerDependencies); 94 for (Value v : op->getResults()) 95 if (isa<BaseMemRefType>(v.getType()) && 96 bufferViewFlowOp.mayBeTerminalBuffer(v)) 97 this->terminals.insert(v); 98 for (Region &r : op->getRegions()) 99 for (BlockArgument v : r.getArguments()) 100 if (isa<BaseMemRefType>(v.getType()) && 101 bufferViewFlowOp.mayBeTerminalBuffer(v)) 102 this->terminals.insert(v); 103 return WalkResult::advance(); 104 } 105 106 // Add additional dependencies created by view changes to the alias list. 107 if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) { 108 registerDependencies(viewInterface.getViewSource(), 109 viewInterface->getResult(0)); 110 return WalkResult::advance(); 111 } 112 113 if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) { 114 // Query all branch interfaces to link block argument dependencies. 115 Block *parentBlock = branchInterface->getBlock(); 116 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); 117 it != e; ++it) { 118 // Query the branch op interface to get the successor operands. 119 auto successorOperands = 120 branchInterface.getSuccessorOperands(it.getIndex()); 121 // Build the actual mapping of values to their immediate dependencies. 122 registerDependencies(successorOperands.getForwardedOperands(), 123 (*it)->getArguments().drop_front( 124 successorOperands.getProducedOperandCount())); 125 } 126 return WalkResult::advance(); 127 } 128 129 if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) { 130 // Query the RegionBranchOpInterface to find potential successor regions. 131 // Extract all entry regions and wire all initial entry successor inputs. 132 SmallVector<RegionSuccessor, 2> entrySuccessors; 133 regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(), 134 entrySuccessors); 135 for (RegionSuccessor &entrySuccessor : entrySuccessors) { 136 // Wire the entry region's successor arguments with the initial 137 // successor inputs. 138 registerDependencies( 139 regionInterface.getEntrySuccessorOperands(entrySuccessor), 140 entrySuccessor.getSuccessorInputs()); 141 } 142 143 // Wire flow between regions and from region exits. 144 for (Region ®ion : regionInterface->getRegions()) { 145 // Iterate over all successor region entries that are reachable from the 146 // current region. 147 SmallVector<RegionSuccessor, 2> successorRegions; 148 regionInterface.getSuccessorRegions(region, successorRegions); 149 for (RegionSuccessor &successorRegion : successorRegions) { 150 // Iterate over all immediate terminator operations and wire the 151 // successor inputs with the successor operands of each terminator. 152 for (Block &block : region) 153 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>( 154 block.getTerminator())) 155 registerDependencies( 156 terminator.getSuccessorOperands(successorRegion), 157 successorRegion.getSuccessorInputs()); 158 } 159 } 160 161 return WalkResult::advance(); 162 } 163 164 // Region terminators are handled together with RegionBranchOpInterface. 165 if (isa<RegionBranchTerminatorOpInterface>(op)) 166 return WalkResult::advance(); 167 168 if (isa<CallOpInterface>(op)) { 169 // This is an intra-function analysis. We have no information about other 170 // functions. Conservatively assume that each operand may alias with each 171 // result. Also mark the results are terminals because the function could 172 // return newly allocated buffers. 173 populateTerminalValues(op); 174 for (Value operand : op->getOperands()) 175 for (Value result : op->getResults()) 176 registerDependencies({operand}, {result}); 177 return WalkResult::advance(); 178 } 179 180 // We have no information about unknown ops. 181 populateTerminalValues(op); 182 183 return WalkResult::advance(); 184 }); 185 } 186 187 bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const { 188 assert(isa<BaseMemRefType>(value.getType()) && "expected memref"); 189 return terminals.contains(value); 190 } 191