xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (revision dbfc38ed6b3f2a9be0b1a86b2a074aad69eb58a6)
1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
12 #include "mlir/Interfaces/CallInterfaces.h"
13 #include "mlir/Interfaces/ControlFlowInterfaces.h"
14 #include "mlir/Interfaces/FunctionInterfaces.h"
15 #include "mlir/Interfaces/ViewLikeInterface.h"
16 #include "llvm/ADT/SetOperations.h"
17 #include "llvm/ADT/SetVector.h"
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 
22 //===----------------------------------------------------------------------===//
23 // BufferViewFlowAnalysis
24 //===----------------------------------------------------------------------===//
25 
26 /// Constructs a new alias analysis using the op provided.
BufferViewFlowAnalysis(Operation * op)27 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
28 
29 static BufferViewFlowAnalysis::ValueSetT
resolveValues(const BufferViewFlowAnalysis::ValueMapT & map,Value value)30 resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
31   BufferViewFlowAnalysis::ValueSetT result;
32   SmallVector<Value, 8> queue;
33   queue.push_back(value);
34   while (!queue.empty()) {
35     Value currentValue = queue.pop_back_val();
36     if (result.insert(currentValue).second) {
37       auto it = map.find(currentValue);
38       if (it != map.end()) {
39         for (Value aliasValue : it->second)
40           queue.push_back(aliasValue);
41       }
42     }
43   }
44   return result;
45 }
46 
47 /// Find all immediate and indirect dependent buffers this value could
48 /// potentially have. Note that the resulting set will also contain the value
49 /// provided as it is a dependent alias of itself.
50 BufferViewFlowAnalysis::ValueSetT
resolve(Value rootValue) const51 BufferViewFlowAnalysis::resolve(Value rootValue) const {
52   return resolveValues(dependencies, rootValue);
53 }
54 
55 BufferViewFlowAnalysis::ValueSetT
resolveReverse(Value rootValue) const56 BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
57   return resolveValues(reverseDependencies, rootValue);
58 }
59 
60 /// Removes the given values from all alias sets.
remove(const SetVector<Value> & aliasValues)61 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
62   for (auto &entry : dependencies)
63     llvm::set_subtract(entry.second, aliasValues);
64 }
65 
rename(Value from,Value to)66 void BufferViewFlowAnalysis::rename(Value from, Value to) {
67   dependencies[to] = dependencies[from];
68   dependencies.erase(from);
69 
70   for (auto &[_, value] : dependencies) {
71     if (value.contains(from)) {
72       value.insert(to);
73       value.erase(from);
74     }
75   }
76 }
77 
78 /// This function constructs a mapping from values to its immediate
79 /// dependencies. It iterates over all blocks, gets their predecessors,
80 /// determines the values that will be passed to the corresponding block
81 /// arguments and inserts them into the underlying map. Furthermore, it wires
82 /// successor regions and branch-like return operations from nested regions.
build(Operation * op)83 void BufferViewFlowAnalysis::build(Operation *op) {
84   // Registers all dependencies of the given values.
85   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
86     for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
87       this->dependencies[value].insert(dep);
88       this->reverseDependencies[dep].insert(value);
89     }
90   };
91 
92   // Mark all buffer results and buffer region entry block arguments of the
93   // given op as terminals.
94   auto populateTerminalValues = [&](Operation *op) {
95     for (Value v : op->getResults())
96       if (isa<BaseMemRefType>(v.getType()))
97         this->terminals.insert(v);
98     for (Region &r : op->getRegions())
99       for (BlockArgument v : r.getArguments())
100         if (isa<BaseMemRefType>(v.getType()))
101           this->terminals.insert(v);
102   };
103 
104   op->walk([&](Operation *op) {
105     // Query BufferViewFlowOpInterface. If the op does not implement that
106     // interface, try to infer the dependencies from other interfaces that the
107     // op may implement.
108     if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
109       bufferViewFlowOp.populateDependencies(registerDependencies);
110       for (Value v : op->getResults())
111         if (isa<BaseMemRefType>(v.getType()) &&
112             bufferViewFlowOp.mayBeTerminalBuffer(v))
113           this->terminals.insert(v);
114       for (Region &r : op->getRegions())
115         for (BlockArgument v : r.getArguments())
116           if (isa<BaseMemRefType>(v.getType()) &&
117               bufferViewFlowOp.mayBeTerminalBuffer(v))
118             this->terminals.insert(v);
119       return WalkResult::advance();
120     }
121 
122     // Add additional dependencies created by view changes to the alias list.
123     if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
124       registerDependencies(viewInterface.getViewSource(),
125                            viewInterface->getResult(0));
126       return WalkResult::advance();
127     }
128 
129     if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
130       // Query all branch interfaces to link block argument dependencies.
131       Block *parentBlock = branchInterface->getBlock();
132       for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
133            it != e; ++it) {
134         // Query the branch op interface to get the successor operands.
135         auto successorOperands =
136             branchInterface.getSuccessorOperands(it.getIndex());
137         // Build the actual mapping of values to their immediate dependencies.
138         registerDependencies(successorOperands.getForwardedOperands(),
139                              (*it)->getArguments().drop_front(
140                                  successorOperands.getProducedOperandCount()));
141       }
142       return WalkResult::advance();
143     }
144 
145     if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
146       // Query the RegionBranchOpInterface to find potential successor regions.
147       // Extract all entry regions and wire all initial entry successor inputs.
148       SmallVector<RegionSuccessor, 2> entrySuccessors;
149       regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
150                                           entrySuccessors);
151       for (RegionSuccessor &entrySuccessor : entrySuccessors) {
152         // Wire the entry region's successor arguments with the initial
153         // successor inputs.
154         registerDependencies(
155             regionInterface.getEntrySuccessorOperands(entrySuccessor),
156             entrySuccessor.getSuccessorInputs());
157       }
158 
159       // Wire flow between regions and from region exits.
160       for (Region &region : regionInterface->getRegions()) {
161         // Iterate over all successor region entries that are reachable from the
162         // current region.
163         SmallVector<RegionSuccessor, 2> successorRegions;
164         regionInterface.getSuccessorRegions(region, successorRegions);
165         for (RegionSuccessor &successorRegion : successorRegions) {
166           // Iterate over all immediate terminator operations and wire the
167           // successor inputs with the successor operands of each terminator.
168           for (Block &block : region)
169             if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
170                     block.getTerminator()))
171               registerDependencies(
172                   terminator.getSuccessorOperands(successorRegion),
173                   successorRegion.getSuccessorInputs());
174         }
175       }
176 
177       return WalkResult::advance();
178     }
179 
180     // Region terminators are handled together with RegionBranchOpInterface.
181     if (isa<RegionBranchTerminatorOpInterface>(op))
182       return WalkResult::advance();
183 
184     if (isa<CallOpInterface>(op)) {
185       // This is an intra-function analysis. We have no information about other
186       // functions. Conservatively assume that each operand may alias with each
187       // result. Also mark the results are terminals because the function could
188       // return newly allocated buffers.
189       populateTerminalValues(op);
190       for (Value operand : op->getOperands())
191         for (Value result : op->getResults())
192           registerDependencies({operand}, {result});
193       return WalkResult::advance();
194     }
195 
196     // We have no information about unknown ops.
197     populateTerminalValues(op);
198 
199     return WalkResult::advance();
200   });
201 }
202 
mayBeTerminalBuffer(Value value) const203 bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
204   assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
205   return terminals.contains(value);
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // BufferOriginAnalysis
210 //===----------------------------------------------------------------------===//
211 
212 /// Return "true" if the given value is the result of a memory allocation.
hasAllocateSideEffect(Value v)213 static bool hasAllocateSideEffect(Value v) {
214   Operation *op = v.getDefiningOp();
215   if (!op)
216     return false;
217   return hasEffect<MemoryEffects::Allocate>(op, v);
218 }
219 
220 /// Return "true" if the given value is a function block argument.
isFunctionArgument(Value v)221 static bool isFunctionArgument(Value v) {
222   auto bbArg = dyn_cast<BlockArgument>(v);
223   if (!bbArg)
224     return false;
225   Block *b = bbArg.getOwner();
226   auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
227   if (!funcOp)
228     return false;
229   return bbArg.getOwner() == &funcOp.getFunctionBody().front();
230 }
231 
232 /// Given a memref value, return the "base" value by skipping over all
233 /// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
getViewBase(Value value)234 static Value getViewBase(Value value) {
235   while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
236     value = viewLikeOp.getViewSource();
237   return value;
238 }
239 
BufferOriginAnalysis(Operation * op)240 BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
241 
isSameAllocation(Value v1,Value v2)242 std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
243   assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
244   assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
245 
246   // Skip over all view-like ops.
247   v1 = getViewBase(v1);
248   v2 = getViewBase(v2);
249 
250   // Fast path: If both buffers are the same SSA value, we can be sure that
251   // they originate from the same allocation.
252   if (v1 == v2)
253     return true;
254 
255   // Compute the SSA values from which the buffers `v1` and `v2` originate.
256   SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
257   SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
258 
259   // Originating buffers are "terminal" if they could not be traced back any
260   // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
261   // - function block arguments
262   // - values defined by allocation ops such as "memref.alloc"
263   // - values defined by ops that are unknown to the buffer view flow analysis
264   // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
265   SmallPtrSet<Value, 16> terminal1, terminal2;
266 
267   // While gathering terminal buffers, keep track of whether all terminal
268   // buffers are newly allocated buffer or function entry arguments.
269   bool allAllocs1 = true, allAllocs2 = true;
270   bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
271 
272   // Helper function that gathers terminal buffers among `origin`.
273   auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
274                                       SmallPtrSet<Value, 16> &terminal,
275                                       bool &allAllocs,
276                                       bool &allAllocsOrFuncEntryArgs) {
277     for (Value v : origin) {
278       if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
279         terminal.insert(v);
280         allAllocs &= hasAllocateSideEffect(v);
281         allAllocsOrFuncEntryArgs &=
282             isFunctionArgument(v) || hasAllocateSideEffect(v);
283       }
284     }
285     assert(!terminal.empty() && "expected non-empty terminal set");
286   };
287 
288   // Gather terminal buffers for `v1` and `v2`.
289   gatherTerminalBuffers(origin1, terminal1, allAllocs1,
290                         allAllocsOrFuncEntryArgs1);
291   gatherTerminalBuffers(origin2, terminal2, allAllocs2,
292                         allAllocsOrFuncEntryArgs2);
293 
294   // If both `v1` and `v2` have a single matching terminal buffer, they are
295   // guaranteed to originate from the same buffer allocation.
296   if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
297       *terminal1.begin() == *terminal2.begin())
298     return true;
299 
300   // At least one of the two values has multiple terminals.
301 
302   // Check if there is overlap between the terminal buffers of `v1` and `v2`.
303   bool distinctTerminalSets = true;
304   for (Value v : terminal1)
305     distinctTerminalSets &= !terminal2.contains(v);
306   // If there is overlap between the terminal buffers of `v1` and `v2`, we
307   // cannot make an accurate decision without further analysis.
308   if (!distinctTerminalSets)
309     return std::nullopt;
310 
311   // If `v1` originates from only allocs, and `v2` is guaranteed to originate
312   // from different allocations (that is guaranteed if `v2` originates from
313   // only distinct allocs or function entry arguments), we can be sure that
314   // `v1` and `v2` originate from different allocations. The same argument can
315   // be made when swapping `v1` and `v2`.
316   bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
317   bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
318   if (isolatedAlloc1 || isolatedAlloc2)
319     return false;
320 
321   // Otherwise: We do not know whether `v1` and `v2` originate from the same
322   // allocation or not.
323   // TODO: Function arguments are currently handled conservatively. We assume
324   // that they could be the same allocation.
325   // TODO: Terminals other than allocations and function arguments are
326   // currently handled conservatively. We assume that they could be the same
327   // allocation. E.g., we currently return "nullopt" for values that originate
328   // from different "memref.get_global" ops (with different symbols).
329   return std::nullopt;
330 }
331