1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Interfaces/ControlFlowInterfaces.h" 13 #include "mlir/Interfaces/ViewLikeInterface.h" 14 #include "llvm/ADT/SetOperations.h" 15 16 using namespace mlir; 17 18 /// Constructs a new alias analysis using the op provided. 19 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); } 20 21 /// Find all immediate and indirect dependent buffers this value could 22 /// potentially have. Note that the resulting set will also contain the value 23 /// provided as it is a dependent alias of itself. 24 BufferViewFlowAnalysis::ValueSetT 25 BufferViewFlowAnalysis::resolve(Value rootValue) const { 26 ValueSetT result; 27 SmallVector<Value, 8> queue; 28 queue.push_back(rootValue); 29 while (!queue.empty()) { 30 Value currentValue = queue.pop_back_val(); 31 if (result.insert(currentValue).second) { 32 auto it = dependencies.find(currentValue); 33 if (it != dependencies.end()) { 34 for (Value aliasValue : it->second) 35 queue.push_back(aliasValue); 36 } 37 } 38 } 39 return result; 40 } 41 42 /// Removes the given values from all alias sets. 43 void BufferViewFlowAnalysis::remove(const SmallPtrSetImpl<Value> &aliasValues) { 44 for (auto &entry : dependencies) 45 llvm::set_subtract(entry.second, aliasValues); 46 } 47 48 /// This function constructs a mapping from values to its immediate 49 /// dependencies. It iterates over all blocks, gets their predecessors, 50 /// determines the values that will be passed to the corresponding block 51 /// arguments and inserts them into the underlying map. Furthermore, it wires 52 /// successor regions and branch-like return operations from nested regions. 53 void BufferViewFlowAnalysis::build(Operation *op) { 54 // Registers all dependencies of the given values. 55 auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { 56 for (auto [value, dep] : llvm::zip(values, dependencies)) 57 this->dependencies[value].insert(dep); 58 }; 59 60 // Add additional dependencies created by view changes to the alias list. 61 op->walk([&](ViewLikeOpInterface viewInterface) { 62 dependencies[viewInterface.getViewSource()].insert( 63 viewInterface->getResult(0)); 64 }); 65 66 // Query all branch interfaces to link block argument dependencies. 67 op->walk([&](BranchOpInterface branchInterface) { 68 Block *parentBlock = branchInterface->getBlock(); 69 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); 70 it != e; ++it) { 71 // Query the branch op interface to get the successor operands. 72 auto successorOperands = 73 branchInterface.getSuccessorOperands(it.getIndex()); 74 // Build the actual mapping of values to their immediate dependencies. 75 registerDependencies(successorOperands.getForwardedOperands(), 76 (*it)->getArguments().drop_front( 77 successorOperands.getProducedOperandCount())); 78 } 79 }); 80 81 // Query the RegionBranchOpInterface to find potential successor regions. 82 op->walk([&](RegionBranchOpInterface regionInterface) { 83 // Extract all entry regions and wire all initial entry successor inputs. 84 SmallVector<RegionSuccessor, 2> entrySuccessors; 85 regionInterface.getSuccessorRegions(/*index=*/llvm::None, entrySuccessors); 86 for (RegionSuccessor &entrySuccessor : entrySuccessors) { 87 // Wire the entry region's successor arguments with the initial 88 // successor inputs. 89 assert(entrySuccessor.getSuccessor() && 90 "Invalid entry region without an attached successor region"); 91 registerDependencies( 92 regionInterface.getSuccessorEntryOperands( 93 entrySuccessor.getSuccessor()->getRegionNumber()), 94 entrySuccessor.getSuccessorInputs()); 95 } 96 97 // Wire flow between regions and from region exits. 98 for (Region ®ion : regionInterface->getRegions()) { 99 // Iterate over all successor region entries that are reachable from the 100 // current region. 101 SmallVector<RegionSuccessor, 2> successorRegions; 102 regionInterface.getSuccessorRegions(region.getRegionNumber(), 103 successorRegions); 104 for (RegionSuccessor &successorRegion : successorRegions) { 105 // Determine the current region index (if any). 106 Optional<unsigned> regionIndex; 107 Region *regionSuccessor = successorRegion.getSuccessor(); 108 if (regionSuccessor) 109 regionIndex = regionSuccessor->getRegionNumber(); 110 // Iterate over all immediate terminator operations and wire the 111 // successor inputs with the successor operands of each terminator. 112 for (Block &block : region) { 113 auto successorOperands = getRegionBranchSuccessorOperands( 114 block.getTerminator(), regionIndex); 115 if (successorOperands) { 116 registerDependencies(*successorOperands, 117 successorRegion.getSuccessorInputs()); 118 } 119 } 120 } 121 } 122 }); 123 124 // TODO: This should be an interface. 125 op->walk([&](arith::SelectOp selectOp) { 126 registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()}); 127 registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()}); 128 }); 129 } 130