123dec4a3SJohannes Reifferscheid //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======// 223dec4a3SJohannes Reifferscheid // 323dec4a3SJohannes Reifferscheid // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 423dec4a3SJohannes Reifferscheid // See https://llvm.org/LICENSE.txt for license information. 523dec4a3SJohannes Reifferscheid // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 623dec4a3SJohannes Reifferscheid // 723dec4a3SJohannes Reifferscheid //===----------------------------------------------------------------------===// 823dec4a3SJohannes Reifferscheid 923dec4a3SJohannes Reifferscheid #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" 1023dec4a3SJohannes Reifferscheid 1123dec4a3SJohannes Reifferscheid #include "mlir/Interfaces/ControlFlowInterfaces.h" 1223dec4a3SJohannes Reifferscheid #include "mlir/Interfaces/ViewLikeInterface.h" 1323dec4a3SJohannes Reifferscheid #include "llvm/ADT/SetOperations.h" 14e6edc1bdSUday Bondhugula #include "llvm/ADT/SetVector.h" 1523dec4a3SJohannes Reifferscheid 1623dec4a3SJohannes Reifferscheid using namespace mlir; 1723dec4a3SJohannes Reifferscheid 1823dec4a3SJohannes Reifferscheid /// Constructs a new alias analysis using the op provided. 1923dec4a3SJohannes Reifferscheid BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); } 2023dec4a3SJohannes Reifferscheid 2123dec4a3SJohannes Reifferscheid /// Find all immediate and indirect dependent buffers this value could 2223dec4a3SJohannes Reifferscheid /// potentially have. Note that the resulting set will also contain the value 2323dec4a3SJohannes Reifferscheid /// provided as it is a dependent alias of itself. 2423dec4a3SJohannes Reifferscheid BufferViewFlowAnalysis::ValueSetT 2523dec4a3SJohannes Reifferscheid BufferViewFlowAnalysis::resolve(Value rootValue) const { 2623dec4a3SJohannes Reifferscheid ValueSetT result; 2723dec4a3SJohannes Reifferscheid SmallVector<Value, 8> queue; 2823dec4a3SJohannes Reifferscheid queue.push_back(rootValue); 2923dec4a3SJohannes Reifferscheid while (!queue.empty()) { 3023dec4a3SJohannes Reifferscheid Value currentValue = queue.pop_back_val(); 3123dec4a3SJohannes Reifferscheid if (result.insert(currentValue).second) { 3223dec4a3SJohannes Reifferscheid auto it = dependencies.find(currentValue); 3323dec4a3SJohannes Reifferscheid if (it != dependencies.end()) { 3423dec4a3SJohannes Reifferscheid for (Value aliasValue : it->second) 3523dec4a3SJohannes Reifferscheid queue.push_back(aliasValue); 3623dec4a3SJohannes Reifferscheid } 3723dec4a3SJohannes Reifferscheid } 3823dec4a3SJohannes Reifferscheid } 3923dec4a3SJohannes Reifferscheid return result; 4023dec4a3SJohannes Reifferscheid } 4123dec4a3SJohannes Reifferscheid 4223dec4a3SJohannes Reifferscheid /// Removes the given values from all alias sets. 43e6edc1bdSUday Bondhugula void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) { 4423dec4a3SJohannes Reifferscheid for (auto &entry : dependencies) 4523dec4a3SJohannes Reifferscheid llvm::set_subtract(entry.second, aliasValues); 4623dec4a3SJohannes Reifferscheid } 4723dec4a3SJohannes Reifferscheid 4823dec4a3SJohannes Reifferscheid /// This function constructs a mapping from values to its immediate 4923dec4a3SJohannes Reifferscheid /// dependencies. It iterates over all blocks, gets their predecessors, 5023dec4a3SJohannes Reifferscheid /// determines the values that will be passed to the corresponding block 5123dec4a3SJohannes Reifferscheid /// arguments and inserts them into the underlying map. Furthermore, it wires 5223dec4a3SJohannes Reifferscheid /// successor regions and branch-like return operations from nested regions. 5323dec4a3SJohannes Reifferscheid void BufferViewFlowAnalysis::build(Operation *op) { 5423dec4a3SJohannes Reifferscheid // Registers all dependencies of the given values. 5523dec4a3SJohannes Reifferscheid auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { 5623dec4a3SJohannes Reifferscheid for (auto [value, dep] : llvm::zip(values, dependencies)) 5723dec4a3SJohannes Reifferscheid this->dependencies[value].insert(dep); 5823dec4a3SJohannes Reifferscheid }; 5923dec4a3SJohannes Reifferscheid 60*38bef476SMatthias Springer op->walk([&](Operation *op) { 61*38bef476SMatthias Springer // TODO: We should have an op interface instead of a hard-coded list of 62*38bef476SMatthias Springer // interfaces/ops. 63*38bef476SMatthias Springer 6423dec4a3SJohannes Reifferscheid // Add additional dependencies created by view changes to the alias list. 65*38bef476SMatthias Springer if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) { 6623dec4a3SJohannes Reifferscheid dependencies[viewInterface.getViewSource()].insert( 6723dec4a3SJohannes Reifferscheid viewInterface->getResult(0)); 68*38bef476SMatthias Springer return WalkResult::advance(); 69*38bef476SMatthias Springer } 7023dec4a3SJohannes Reifferscheid 71*38bef476SMatthias Springer if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) { 7223dec4a3SJohannes Reifferscheid // Query all branch interfaces to link block argument dependencies. 7323dec4a3SJohannes Reifferscheid Block *parentBlock = branchInterface->getBlock(); 7423dec4a3SJohannes Reifferscheid for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); 7523dec4a3SJohannes Reifferscheid it != e; ++it) { 7623dec4a3SJohannes Reifferscheid // Query the branch op interface to get the successor operands. 7723dec4a3SJohannes Reifferscheid auto successorOperands = 7823dec4a3SJohannes Reifferscheid branchInterface.getSuccessorOperands(it.getIndex()); 7923dec4a3SJohannes Reifferscheid // Build the actual mapping of values to their immediate dependencies. 8023dec4a3SJohannes Reifferscheid registerDependencies(successorOperands.getForwardedOperands(), 8123dec4a3SJohannes Reifferscheid (*it)->getArguments().drop_front( 8223dec4a3SJohannes Reifferscheid successorOperands.getProducedOperandCount())); 8323dec4a3SJohannes Reifferscheid } 84*38bef476SMatthias Springer return WalkResult::advance(); 85*38bef476SMatthias Springer } 8623dec4a3SJohannes Reifferscheid 87*38bef476SMatthias Springer if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) { 8823dec4a3SJohannes Reifferscheid // Query the RegionBranchOpInterface to find potential successor regions. 8923dec4a3SJohannes Reifferscheid // Extract all entry regions and wire all initial entry successor inputs. 9023dec4a3SJohannes Reifferscheid SmallVector<RegionSuccessor, 2> entrySuccessors; 911a36588eSKazu Hirata regionInterface.getSuccessorRegions(/*index=*/std::nullopt, 921a36588eSKazu Hirata entrySuccessors); 9323dec4a3SJohannes Reifferscheid for (RegionSuccessor &entrySuccessor : entrySuccessors) { 9423dec4a3SJohannes Reifferscheid // Wire the entry region's successor arguments with the initial 9523dec4a3SJohannes Reifferscheid // successor inputs. 9623dec4a3SJohannes Reifferscheid assert(entrySuccessor.getSuccessor() && 9723dec4a3SJohannes Reifferscheid "Invalid entry region without an attached successor region"); 9823dec4a3SJohannes Reifferscheid registerDependencies( 9923dec4a3SJohannes Reifferscheid regionInterface.getSuccessorEntryOperands( 10023dec4a3SJohannes Reifferscheid entrySuccessor.getSuccessor()->getRegionNumber()), 10123dec4a3SJohannes Reifferscheid entrySuccessor.getSuccessorInputs()); 10223dec4a3SJohannes Reifferscheid } 10323dec4a3SJohannes Reifferscheid 10423dec4a3SJohannes Reifferscheid // Wire flow between regions and from region exits. 10523dec4a3SJohannes Reifferscheid for (Region ®ion : regionInterface->getRegions()) { 10623dec4a3SJohannes Reifferscheid // Iterate over all successor region entries that are reachable from the 10723dec4a3SJohannes Reifferscheid // current region. 10823dec4a3SJohannes Reifferscheid SmallVector<RegionSuccessor, 2> successorRegions; 10923dec4a3SJohannes Reifferscheid regionInterface.getSuccessorRegions(region.getRegionNumber(), 11023dec4a3SJohannes Reifferscheid successorRegions); 11123dec4a3SJohannes Reifferscheid for (RegionSuccessor &successorRegion : successorRegions) { 11223dec4a3SJohannes Reifferscheid // Determine the current region index (if any). 11322426110SRamkumar Ramachandra std::optional<unsigned> regionIndex; 11423dec4a3SJohannes Reifferscheid Region *regionSuccessor = successorRegion.getSuccessor(); 11523dec4a3SJohannes Reifferscheid if (regionSuccessor) 11623dec4a3SJohannes Reifferscheid regionIndex = regionSuccessor->getRegionNumber(); 11723dec4a3SJohannes Reifferscheid // Iterate over all immediate terminator operations and wire the 11823dec4a3SJohannes Reifferscheid // successor inputs with the successor operands of each terminator. 11923dec4a3SJohannes Reifferscheid for (Block &block : region) { 12023dec4a3SJohannes Reifferscheid auto successorOperands = getRegionBranchSuccessorOperands( 12123dec4a3SJohannes Reifferscheid block.getTerminator(), regionIndex); 12223dec4a3SJohannes Reifferscheid if (successorOperands) { 12323dec4a3SJohannes Reifferscheid registerDependencies(*successorOperands, 12423dec4a3SJohannes Reifferscheid successorRegion.getSuccessorInputs()); 12523dec4a3SJohannes Reifferscheid } 12623dec4a3SJohannes Reifferscheid } 12723dec4a3SJohannes Reifferscheid } 12823dec4a3SJohannes Reifferscheid } 12923dec4a3SJohannes Reifferscheid 130*38bef476SMatthias Springer return WalkResult::advance(); 131*38bef476SMatthias Springer } 132*38bef476SMatthias Springer 133*38bef476SMatthias Springer // Unknown op: Assume that all operands alias with all results. 134*38bef476SMatthias Springer for (Value operand : op->getOperands()) { 135*38bef476SMatthias Springer if (!isa<BaseMemRefType>(operand.getType())) 136*38bef476SMatthias Springer continue; 137*38bef476SMatthias Springer for (Value result : op->getResults()) { 138*38bef476SMatthias Springer if (!isa<BaseMemRefType>(result.getType())) 139*38bef476SMatthias Springer continue; 140*38bef476SMatthias Springer registerDependencies({operand}, {result}); 141*38bef476SMatthias Springer } 142*38bef476SMatthias Springer } 143*38bef476SMatthias Springer return WalkResult::advance(); 14423dec4a3SJohannes Reifferscheid }); 14523dec4a3SJohannes Reifferscheid } 146