1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Interfaces/ControlFlowInterfaces.h" 13 #include "mlir/Interfaces/ViewLikeInterface.h" 14 #include "llvm/ADT/SetOperations.h" 15 #include "llvm/ADT/SetVector.h" 16 17 using namespace mlir; 18 19 /// Constructs a new alias analysis using the op provided. 20 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); } 21 22 /// Find all immediate and indirect dependent buffers this value could 23 /// potentially have. Note that the resulting set will also contain the value 24 /// provided as it is a dependent alias of itself. 25 BufferViewFlowAnalysis::ValueSetT 26 BufferViewFlowAnalysis::resolve(Value rootValue) const { 27 ValueSetT result; 28 SmallVector<Value, 8> queue; 29 queue.push_back(rootValue); 30 while (!queue.empty()) { 31 Value currentValue = queue.pop_back_val(); 32 if (result.insert(currentValue).second) { 33 auto it = dependencies.find(currentValue); 34 if (it != dependencies.end()) { 35 for (Value aliasValue : it->second) 36 queue.push_back(aliasValue); 37 } 38 } 39 } 40 return result; 41 } 42 43 /// Removes the given values from all alias sets. 44 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) { 45 for (auto &entry : dependencies) 46 llvm::set_subtract(entry.second, aliasValues); 47 } 48 49 /// This function constructs a mapping from values to its immediate 50 /// dependencies. It iterates over all blocks, gets their predecessors, 51 /// determines the values that will be passed to the corresponding block 52 /// arguments and inserts them into the underlying map. Furthermore, it wires 53 /// successor regions and branch-like return operations from nested regions. 54 void BufferViewFlowAnalysis::build(Operation *op) { 55 // Registers all dependencies of the given values. 56 auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { 57 for (auto [value, dep] : llvm::zip(values, dependencies)) 58 this->dependencies[value].insert(dep); 59 }; 60 61 // Add additional dependencies created by view changes to the alias list. 62 op->walk([&](ViewLikeOpInterface viewInterface) { 63 dependencies[viewInterface.getViewSource()].insert( 64 viewInterface->getResult(0)); 65 }); 66 67 // Query all branch interfaces to link block argument dependencies. 68 op->walk([&](BranchOpInterface branchInterface) { 69 Block *parentBlock = branchInterface->getBlock(); 70 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); 71 it != e; ++it) { 72 // Query the branch op interface to get the successor operands. 73 auto successorOperands = 74 branchInterface.getSuccessorOperands(it.getIndex()); 75 // Build the actual mapping of values to their immediate dependencies. 76 registerDependencies(successorOperands.getForwardedOperands(), 77 (*it)->getArguments().drop_front( 78 successorOperands.getProducedOperandCount())); 79 } 80 }); 81 82 // Query the RegionBranchOpInterface to find potential successor regions. 83 op->walk([&](RegionBranchOpInterface regionInterface) { 84 // Extract all entry regions and wire all initial entry successor inputs. 85 SmallVector<RegionSuccessor, 2> entrySuccessors; 86 regionInterface.getSuccessorRegions(/*index=*/std::nullopt, 87 entrySuccessors); 88 for (RegionSuccessor &entrySuccessor : entrySuccessors) { 89 // Wire the entry region's successor arguments with the initial 90 // successor inputs. 91 assert(entrySuccessor.getSuccessor() && 92 "Invalid entry region without an attached successor region"); 93 registerDependencies( 94 regionInterface.getSuccessorEntryOperands( 95 entrySuccessor.getSuccessor()->getRegionNumber()), 96 entrySuccessor.getSuccessorInputs()); 97 } 98 99 // Wire flow between regions and from region exits. 100 for (Region ®ion : regionInterface->getRegions()) { 101 // Iterate over all successor region entries that are reachable from the 102 // current region. 103 SmallVector<RegionSuccessor, 2> successorRegions; 104 regionInterface.getSuccessorRegions(region.getRegionNumber(), 105 successorRegions); 106 for (RegionSuccessor &successorRegion : successorRegions) { 107 // Determine the current region index (if any). 108 std::optional<unsigned> regionIndex; 109 Region *regionSuccessor = successorRegion.getSuccessor(); 110 if (regionSuccessor) 111 regionIndex = regionSuccessor->getRegionNumber(); 112 // Iterate over all immediate terminator operations and wire the 113 // successor inputs with the successor operands of each terminator. 114 for (Block &block : region) { 115 auto successorOperands = getRegionBranchSuccessorOperands( 116 block.getTerminator(), regionIndex); 117 if (successorOperands) { 118 registerDependencies(*successorOperands, 119 successorRegion.getSuccessorInputs()); 120 } 121 } 122 } 123 } 124 }); 125 126 // TODO: This should be an interface. 127 op->walk([&](arith::SelectOp selectOp) { 128 registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()}); 129 registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()}); 130 }); 131 } 132