xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (revision 10ae8ae8375d6b69064204338a33500917749da9)
1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
10 
11 #include "mlir/Interfaces/ControlFlowInterfaces.h"
12 #include "mlir/Interfaces/ViewLikeInterface.h"
13 #include "llvm/ADT/SetOperations.h"
14 #include "llvm/ADT/SetVector.h"
15 
16 using namespace mlir;
17 
18 /// Constructs a new alias analysis using the op provided.
19 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
20 
21 /// Find all immediate and indirect dependent buffers this value could
22 /// potentially have. Note that the resulting set will also contain the value
23 /// provided as it is a dependent alias of itself.
24 BufferViewFlowAnalysis::ValueSetT
25 BufferViewFlowAnalysis::resolve(Value rootValue) const {
26   ValueSetT result;
27   SmallVector<Value, 8> queue;
28   queue.push_back(rootValue);
29   while (!queue.empty()) {
30     Value currentValue = queue.pop_back_val();
31     if (result.insert(currentValue).second) {
32       auto it = dependencies.find(currentValue);
33       if (it != dependencies.end()) {
34         for (Value aliasValue : it->second)
35           queue.push_back(aliasValue);
36       }
37     }
38   }
39   return result;
40 }
41 
42 /// Removes the given values from all alias sets.
43 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
44   for (auto &entry : dependencies)
45     llvm::set_subtract(entry.second, aliasValues);
46 }
47 
48 void BufferViewFlowAnalysis::rename(Value from, Value to) {
49   dependencies[to] = dependencies[from];
50   dependencies.erase(from);
51 
52   for (auto &[key, value] : dependencies) {
53     if (value.contains(from)) {
54       value.insert(to);
55       value.erase(from);
56     }
57   }
58 }
59 
60 /// This function constructs a mapping from values to its immediate
61 /// dependencies. It iterates over all blocks, gets their predecessors,
62 /// determines the values that will be passed to the corresponding block
63 /// arguments and inserts them into the underlying map. Furthermore, it wires
64 /// successor regions and branch-like return operations from nested regions.
65 void BufferViewFlowAnalysis::build(Operation *op) {
66   // Registers all dependencies of the given values.
67   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
68     for (auto [value, dep] : llvm::zip(values, dependencies))
69       this->dependencies[value].insert(dep);
70   };
71 
72   op->walk([&](Operation *op) {
73     // TODO: We should have an op interface instead of a hard-coded list of
74     // interfaces/ops.
75 
76     // Add additional dependencies created by view changes to the alias list.
77     if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
78       dependencies[viewInterface.getViewSource()].insert(
79           viewInterface->getResult(0));
80       return WalkResult::advance();
81     }
82 
83     if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
84       // Query all branch interfaces to link block argument dependencies.
85       Block *parentBlock = branchInterface->getBlock();
86       for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
87            it != e; ++it) {
88         // Query the branch op interface to get the successor operands.
89         auto successorOperands =
90             branchInterface.getSuccessorOperands(it.getIndex());
91         // Build the actual mapping of values to their immediate dependencies.
92         registerDependencies(successorOperands.getForwardedOperands(),
93                              (*it)->getArguments().drop_front(
94                                  successorOperands.getProducedOperandCount()));
95       }
96       return WalkResult::advance();
97     }
98 
99     if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
100       // Query the RegionBranchOpInterface to find potential successor regions.
101       // Extract all entry regions and wire all initial entry successor inputs.
102       SmallVector<RegionSuccessor, 2> entrySuccessors;
103       regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
104                                           entrySuccessors);
105       for (RegionSuccessor &entrySuccessor : entrySuccessors) {
106         // Wire the entry region's successor arguments with the initial
107         // successor inputs.
108         registerDependencies(
109             regionInterface.getSuccessorEntryOperands(
110                 entrySuccessor.isParent()
111                     ? std::optional<unsigned>()
112                     : entrySuccessor.getSuccessor()->getRegionNumber()),
113             entrySuccessor.getSuccessorInputs());
114       }
115 
116       // Wire flow between regions and from region exits.
117       for (Region &region : regionInterface->getRegions()) {
118         // Iterate over all successor region entries that are reachable from the
119         // current region.
120         SmallVector<RegionSuccessor, 2> successorRegions;
121         regionInterface.getSuccessorRegions(region.getRegionNumber(),
122                                             successorRegions);
123         for (RegionSuccessor &successorRegion : successorRegions) {
124           // Determine the current region index (if any).
125           std::optional<unsigned> regionIndex;
126           Region *regionSuccessor = successorRegion.getSuccessor();
127           if (regionSuccessor)
128             regionIndex = regionSuccessor->getRegionNumber();
129           // Iterate over all immediate terminator operations and wire the
130           // successor inputs with the successor operands of each terminator.
131           for (Block &block : region)
132             if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
133                     block.getTerminator()))
134               registerDependencies(terminator.getSuccessorOperands(regionIndex),
135                                    successorRegion.getSuccessorInputs());
136         }
137       }
138 
139       return WalkResult::advance();
140     }
141 
142     // Unknown op: Assume that all operands alias with all results.
143     for (Value operand : op->getOperands()) {
144       if (!isa<BaseMemRefType>(operand.getType()))
145         continue;
146       for (Value result : op->getResults()) {
147         if (!isa<BaseMemRefType>(result.getType()))
148           continue;
149         registerDependencies({operand}, {result});
150       }
151     }
152     return WalkResult::advance();
153   });
154 }
155