xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (revision a45e58af1b381cf3c0374332386b8291ec5310f4)
1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
12 #include "mlir/Interfaces/CallInterfaces.h"
13 #include "mlir/Interfaces/ControlFlowInterfaces.h"
14 #include "mlir/Interfaces/FunctionInterfaces.h"
15 #include "mlir/Interfaces/ViewLikeInterface.h"
16 #include "llvm/ADT/SetOperations.h"
17 #include "llvm/ADT/SetVector.h"
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 
22 /// Constructs a new alias analysis using the op provided.
23 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
24 
25 /// Find all immediate and indirect dependent buffers this value could
26 /// potentially have. Note that the resulting set will also contain the value
27 /// provided as it is a dependent alias of itself.
28 BufferViewFlowAnalysis::ValueSetT
29 BufferViewFlowAnalysis::resolve(Value rootValue) const {
30   ValueSetT result;
31   SmallVector<Value, 8> queue;
32   queue.push_back(rootValue);
33   while (!queue.empty()) {
34     Value currentValue = queue.pop_back_val();
35     if (result.insert(currentValue).second) {
36       auto it = dependencies.find(currentValue);
37       if (it != dependencies.end()) {
38         for (Value aliasValue : it->second)
39           queue.push_back(aliasValue);
40       }
41     }
42   }
43   return result;
44 }
45 
46 /// Removes the given values from all alias sets.
47 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
48   for (auto &entry : dependencies)
49     llvm::set_subtract(entry.second, aliasValues);
50 }
51 
52 void BufferViewFlowAnalysis::rename(Value from, Value to) {
53   dependencies[to] = dependencies[from];
54   dependencies.erase(from);
55 
56   for (auto &[_, value] : dependencies) {
57     if (value.contains(from)) {
58       value.insert(to);
59       value.erase(from);
60     }
61   }
62 }
63 
64 /// This function constructs a mapping from values to its immediate
65 /// dependencies. It iterates over all blocks, gets their predecessors,
66 /// determines the values that will be passed to the corresponding block
67 /// arguments and inserts them into the underlying map. Furthermore, it wires
68 /// successor regions and branch-like return operations from nested regions.
69 void BufferViewFlowAnalysis::build(Operation *op) {
70   // Registers all dependencies of the given values.
71   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
72     for (auto [value, dep] : llvm::zip_equal(values, dependencies))
73       this->dependencies[value].insert(dep);
74   };
75 
76   // Mark all buffer results and buffer region entry block arguments of the
77   // given op as terminals.
78   auto populateTerminalValues = [&](Operation *op) {
79     for (Value v : op->getResults())
80       if (isa<BaseMemRefType>(v.getType()))
81         this->terminals.insert(v);
82     for (Region &r : op->getRegions())
83       for (BlockArgument v : r.getArguments())
84         if (isa<BaseMemRefType>(v.getType()))
85           this->terminals.insert(v);
86   };
87 
88   op->walk([&](Operation *op) {
89     // Query BufferViewFlowOpInterface. If the op does not implement that
90     // interface, try to infer the dependencies from other interfaces that the
91     // op may implement.
92     if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
93       bufferViewFlowOp.populateDependencies(registerDependencies);
94       for (Value v : op->getResults())
95         if (isa<BaseMemRefType>(v.getType()) &&
96             bufferViewFlowOp.mayBeTerminalBuffer(v))
97           this->terminals.insert(v);
98       for (Region &r : op->getRegions())
99         for (BlockArgument v : r.getArguments())
100           if (isa<BaseMemRefType>(v.getType()) &&
101               bufferViewFlowOp.mayBeTerminalBuffer(v))
102             this->terminals.insert(v);
103       return WalkResult::advance();
104     }
105 
106     // Add additional dependencies created by view changes to the alias list.
107     if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
108       registerDependencies(viewInterface.getViewSource(),
109                            viewInterface->getResult(0));
110       return WalkResult::advance();
111     }
112 
113     if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
114       // Query all branch interfaces to link block argument dependencies.
115       Block *parentBlock = branchInterface->getBlock();
116       for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
117            it != e; ++it) {
118         // Query the branch op interface to get the successor operands.
119         auto successorOperands =
120             branchInterface.getSuccessorOperands(it.getIndex());
121         // Build the actual mapping of values to their immediate dependencies.
122         registerDependencies(successorOperands.getForwardedOperands(),
123                              (*it)->getArguments().drop_front(
124                                  successorOperands.getProducedOperandCount()));
125       }
126       return WalkResult::advance();
127     }
128 
129     if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
130       // Query the RegionBranchOpInterface to find potential successor regions.
131       // Extract all entry regions and wire all initial entry successor inputs.
132       SmallVector<RegionSuccessor, 2> entrySuccessors;
133       regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
134                                           entrySuccessors);
135       for (RegionSuccessor &entrySuccessor : entrySuccessors) {
136         // Wire the entry region's successor arguments with the initial
137         // successor inputs.
138         registerDependencies(
139             regionInterface.getEntrySuccessorOperands(entrySuccessor),
140             entrySuccessor.getSuccessorInputs());
141       }
142 
143       // Wire flow between regions and from region exits.
144       for (Region &region : regionInterface->getRegions()) {
145         // Iterate over all successor region entries that are reachable from the
146         // current region.
147         SmallVector<RegionSuccessor, 2> successorRegions;
148         regionInterface.getSuccessorRegions(region, successorRegions);
149         for (RegionSuccessor &successorRegion : successorRegions) {
150           // Iterate over all immediate terminator operations and wire the
151           // successor inputs with the successor operands of each terminator.
152           for (Block &block : region)
153             if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
154                     block.getTerminator()))
155               registerDependencies(
156                   terminator.getSuccessorOperands(successorRegion),
157                   successorRegion.getSuccessorInputs());
158         }
159       }
160 
161       return WalkResult::advance();
162     }
163 
164     // Region terminators are handled together with RegionBranchOpInterface.
165     if (isa<RegionBranchTerminatorOpInterface>(op))
166       return WalkResult::advance();
167 
168     if (isa<CallOpInterface>(op)) {
169       // This is an intra-function analysis. We have no information about other
170       // functions. Conservatively assume that each operand may alias with each
171       // result. Also mark the results are terminals because the function could
172       // return newly allocated buffers.
173       populateTerminalValues(op);
174       for (Value operand : op->getOperands())
175         for (Value result : op->getResults())
176           registerDependencies({operand}, {result});
177       return WalkResult::advance();
178     }
179 
180     // We have no information about unknown ops.
181     populateTerminalValues(op);
182 
183     return WalkResult::advance();
184   });
185 }
186 
187 bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
188   assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
189   return terminals.contains(value);
190 }
191