xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (revision 38bef476552021b7ad45d1aa989d250bcd0a38ff)
123dec4a3SJohannes Reifferscheid //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
223dec4a3SJohannes Reifferscheid //
323dec4a3SJohannes Reifferscheid // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
423dec4a3SJohannes Reifferscheid // See https://llvm.org/LICENSE.txt for license information.
523dec4a3SJohannes Reifferscheid // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
623dec4a3SJohannes Reifferscheid //
723dec4a3SJohannes Reifferscheid //===----------------------------------------------------------------------===//
823dec4a3SJohannes Reifferscheid 
923dec4a3SJohannes Reifferscheid #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
1023dec4a3SJohannes Reifferscheid 
1123dec4a3SJohannes Reifferscheid #include "mlir/Interfaces/ControlFlowInterfaces.h"
1223dec4a3SJohannes Reifferscheid #include "mlir/Interfaces/ViewLikeInterface.h"
1323dec4a3SJohannes Reifferscheid #include "llvm/ADT/SetOperations.h"
14e6edc1bdSUday Bondhugula #include "llvm/ADT/SetVector.h"
1523dec4a3SJohannes Reifferscheid 
1623dec4a3SJohannes Reifferscheid using namespace mlir;
1723dec4a3SJohannes Reifferscheid 
1823dec4a3SJohannes Reifferscheid /// Constructs a new alias analysis using the op provided.
1923dec4a3SJohannes Reifferscheid BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
2023dec4a3SJohannes Reifferscheid 
2123dec4a3SJohannes Reifferscheid /// Find all immediate and indirect dependent buffers this value could
2223dec4a3SJohannes Reifferscheid /// potentially have. Note that the resulting set will also contain the value
2323dec4a3SJohannes Reifferscheid /// provided as it is a dependent alias of itself.
2423dec4a3SJohannes Reifferscheid BufferViewFlowAnalysis::ValueSetT
2523dec4a3SJohannes Reifferscheid BufferViewFlowAnalysis::resolve(Value rootValue) const {
2623dec4a3SJohannes Reifferscheid   ValueSetT result;
2723dec4a3SJohannes Reifferscheid   SmallVector<Value, 8> queue;
2823dec4a3SJohannes Reifferscheid   queue.push_back(rootValue);
2923dec4a3SJohannes Reifferscheid   while (!queue.empty()) {
3023dec4a3SJohannes Reifferscheid     Value currentValue = queue.pop_back_val();
3123dec4a3SJohannes Reifferscheid     if (result.insert(currentValue).second) {
3223dec4a3SJohannes Reifferscheid       auto it = dependencies.find(currentValue);
3323dec4a3SJohannes Reifferscheid       if (it != dependencies.end()) {
3423dec4a3SJohannes Reifferscheid         for (Value aliasValue : it->second)
3523dec4a3SJohannes Reifferscheid           queue.push_back(aliasValue);
3623dec4a3SJohannes Reifferscheid       }
3723dec4a3SJohannes Reifferscheid     }
3823dec4a3SJohannes Reifferscheid   }
3923dec4a3SJohannes Reifferscheid   return result;
4023dec4a3SJohannes Reifferscheid }
4123dec4a3SJohannes Reifferscheid 
4223dec4a3SJohannes Reifferscheid /// Removes the given values from all alias sets.
43e6edc1bdSUday Bondhugula void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
4423dec4a3SJohannes Reifferscheid   for (auto &entry : dependencies)
4523dec4a3SJohannes Reifferscheid     llvm::set_subtract(entry.second, aliasValues);
4623dec4a3SJohannes Reifferscheid }
4723dec4a3SJohannes Reifferscheid 
4823dec4a3SJohannes Reifferscheid /// This function constructs a mapping from values to its immediate
4923dec4a3SJohannes Reifferscheid /// dependencies. It iterates over all blocks, gets their predecessors,
5023dec4a3SJohannes Reifferscheid /// determines the values that will be passed to the corresponding block
5123dec4a3SJohannes Reifferscheid /// arguments and inserts them into the underlying map. Furthermore, it wires
5223dec4a3SJohannes Reifferscheid /// successor regions and branch-like return operations from nested regions.
5323dec4a3SJohannes Reifferscheid void BufferViewFlowAnalysis::build(Operation *op) {
5423dec4a3SJohannes Reifferscheid   // Registers all dependencies of the given values.
5523dec4a3SJohannes Reifferscheid   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
5623dec4a3SJohannes Reifferscheid     for (auto [value, dep] : llvm::zip(values, dependencies))
5723dec4a3SJohannes Reifferscheid       this->dependencies[value].insert(dep);
5823dec4a3SJohannes Reifferscheid   };
5923dec4a3SJohannes Reifferscheid 
60*38bef476SMatthias Springer   op->walk([&](Operation *op) {
61*38bef476SMatthias Springer     // TODO: We should have an op interface instead of a hard-coded list of
62*38bef476SMatthias Springer     // interfaces/ops.
63*38bef476SMatthias Springer 
6423dec4a3SJohannes Reifferscheid     // Add additional dependencies created by view changes to the alias list.
65*38bef476SMatthias Springer     if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
6623dec4a3SJohannes Reifferscheid       dependencies[viewInterface.getViewSource()].insert(
6723dec4a3SJohannes Reifferscheid           viewInterface->getResult(0));
68*38bef476SMatthias Springer       return WalkResult::advance();
69*38bef476SMatthias Springer     }
7023dec4a3SJohannes Reifferscheid 
71*38bef476SMatthias Springer     if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
7223dec4a3SJohannes Reifferscheid       // Query all branch interfaces to link block argument dependencies.
7323dec4a3SJohannes Reifferscheid       Block *parentBlock = branchInterface->getBlock();
7423dec4a3SJohannes Reifferscheid       for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
7523dec4a3SJohannes Reifferscheid            it != e; ++it) {
7623dec4a3SJohannes Reifferscheid         // Query the branch op interface to get the successor operands.
7723dec4a3SJohannes Reifferscheid         auto successorOperands =
7823dec4a3SJohannes Reifferscheid             branchInterface.getSuccessorOperands(it.getIndex());
7923dec4a3SJohannes Reifferscheid         // Build the actual mapping of values to their immediate dependencies.
8023dec4a3SJohannes Reifferscheid         registerDependencies(successorOperands.getForwardedOperands(),
8123dec4a3SJohannes Reifferscheid                              (*it)->getArguments().drop_front(
8223dec4a3SJohannes Reifferscheid                                  successorOperands.getProducedOperandCount()));
8323dec4a3SJohannes Reifferscheid       }
84*38bef476SMatthias Springer       return WalkResult::advance();
85*38bef476SMatthias Springer     }
8623dec4a3SJohannes Reifferscheid 
87*38bef476SMatthias Springer     if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
8823dec4a3SJohannes Reifferscheid       // Query the RegionBranchOpInterface to find potential successor regions.
8923dec4a3SJohannes Reifferscheid       // Extract all entry regions and wire all initial entry successor inputs.
9023dec4a3SJohannes Reifferscheid       SmallVector<RegionSuccessor, 2> entrySuccessors;
911a36588eSKazu Hirata       regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
921a36588eSKazu Hirata                                           entrySuccessors);
9323dec4a3SJohannes Reifferscheid       for (RegionSuccessor &entrySuccessor : entrySuccessors) {
9423dec4a3SJohannes Reifferscheid         // Wire the entry region's successor arguments with the initial
9523dec4a3SJohannes Reifferscheid         // successor inputs.
9623dec4a3SJohannes Reifferscheid         assert(entrySuccessor.getSuccessor() &&
9723dec4a3SJohannes Reifferscheid                "Invalid entry region without an attached successor region");
9823dec4a3SJohannes Reifferscheid         registerDependencies(
9923dec4a3SJohannes Reifferscheid             regionInterface.getSuccessorEntryOperands(
10023dec4a3SJohannes Reifferscheid                 entrySuccessor.getSuccessor()->getRegionNumber()),
10123dec4a3SJohannes Reifferscheid             entrySuccessor.getSuccessorInputs());
10223dec4a3SJohannes Reifferscheid       }
10323dec4a3SJohannes Reifferscheid 
10423dec4a3SJohannes Reifferscheid       // Wire flow between regions and from region exits.
10523dec4a3SJohannes Reifferscheid       for (Region &region : regionInterface->getRegions()) {
10623dec4a3SJohannes Reifferscheid         // Iterate over all successor region entries that are reachable from the
10723dec4a3SJohannes Reifferscheid         // current region.
10823dec4a3SJohannes Reifferscheid         SmallVector<RegionSuccessor, 2> successorRegions;
10923dec4a3SJohannes Reifferscheid         regionInterface.getSuccessorRegions(region.getRegionNumber(),
11023dec4a3SJohannes Reifferscheid                                             successorRegions);
11123dec4a3SJohannes Reifferscheid         for (RegionSuccessor &successorRegion : successorRegions) {
11223dec4a3SJohannes Reifferscheid           // Determine the current region index (if any).
11322426110SRamkumar Ramachandra           std::optional<unsigned> regionIndex;
11423dec4a3SJohannes Reifferscheid           Region *regionSuccessor = successorRegion.getSuccessor();
11523dec4a3SJohannes Reifferscheid           if (regionSuccessor)
11623dec4a3SJohannes Reifferscheid             regionIndex = regionSuccessor->getRegionNumber();
11723dec4a3SJohannes Reifferscheid           // Iterate over all immediate terminator operations and wire the
11823dec4a3SJohannes Reifferscheid           // successor inputs with the successor operands of each terminator.
11923dec4a3SJohannes Reifferscheid           for (Block &block : region) {
12023dec4a3SJohannes Reifferscheid             auto successorOperands = getRegionBranchSuccessorOperands(
12123dec4a3SJohannes Reifferscheid                 block.getTerminator(), regionIndex);
12223dec4a3SJohannes Reifferscheid             if (successorOperands) {
12323dec4a3SJohannes Reifferscheid               registerDependencies(*successorOperands,
12423dec4a3SJohannes Reifferscheid                                    successorRegion.getSuccessorInputs());
12523dec4a3SJohannes Reifferscheid             }
12623dec4a3SJohannes Reifferscheid           }
12723dec4a3SJohannes Reifferscheid         }
12823dec4a3SJohannes Reifferscheid       }
12923dec4a3SJohannes Reifferscheid 
130*38bef476SMatthias Springer       return WalkResult::advance();
131*38bef476SMatthias Springer     }
132*38bef476SMatthias Springer 
133*38bef476SMatthias Springer     // Unknown op: Assume that all operands alias with all results.
134*38bef476SMatthias Springer     for (Value operand : op->getOperands()) {
135*38bef476SMatthias Springer       if (!isa<BaseMemRefType>(operand.getType()))
136*38bef476SMatthias Springer         continue;
137*38bef476SMatthias Springer       for (Value result : op->getResults()) {
138*38bef476SMatthias Springer         if (!isa<BaseMemRefType>(result.getType()))
139*38bef476SMatthias Springer           continue;
140*38bef476SMatthias Springer         registerDependencies({operand}, {result});
141*38bef476SMatthias Springer       }
142*38bef476SMatthias Springer     }
143*38bef476SMatthias Springer     return WalkResult::advance();
14423dec4a3SJohannes Reifferscheid   });
14523dec4a3SJohannes Reifferscheid }
146