xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (revision 1a36588ec64ae8576e531e6f0b49eadb90ab0b11)
1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 #include "mlir/Interfaces/ViewLikeInterface.h"
14 #include "llvm/ADT/SetOperations.h"
15 
16 using namespace mlir;
17 
18 /// Constructs a new alias analysis using the op provided.
19 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
20 
21 /// Find all immediate and indirect dependent buffers this value could
22 /// potentially have. Note that the resulting set will also contain the value
23 /// provided as it is a dependent alias of itself.
24 BufferViewFlowAnalysis::ValueSetT
25 BufferViewFlowAnalysis::resolve(Value rootValue) const {
26   ValueSetT result;
27   SmallVector<Value, 8> queue;
28   queue.push_back(rootValue);
29   while (!queue.empty()) {
30     Value currentValue = queue.pop_back_val();
31     if (result.insert(currentValue).second) {
32       auto it = dependencies.find(currentValue);
33       if (it != dependencies.end()) {
34         for (Value aliasValue : it->second)
35           queue.push_back(aliasValue);
36       }
37     }
38   }
39   return result;
40 }
41 
42 /// Removes the given values from all alias sets.
43 void BufferViewFlowAnalysis::remove(const SmallPtrSetImpl<Value> &aliasValues) {
44   for (auto &entry : dependencies)
45     llvm::set_subtract(entry.second, aliasValues);
46 }
47 
48 /// This function constructs a mapping from values to its immediate
49 /// dependencies. It iterates over all blocks, gets their predecessors,
50 /// determines the values that will be passed to the corresponding block
51 /// arguments and inserts them into the underlying map. Furthermore, it wires
52 /// successor regions and branch-like return operations from nested regions.
53 void BufferViewFlowAnalysis::build(Operation *op) {
54   // Registers all dependencies of the given values.
55   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
56     for (auto [value, dep] : llvm::zip(values, dependencies))
57       this->dependencies[value].insert(dep);
58   };
59 
60   // Add additional dependencies created by view changes to the alias list.
61   op->walk([&](ViewLikeOpInterface viewInterface) {
62     dependencies[viewInterface.getViewSource()].insert(
63         viewInterface->getResult(0));
64   });
65 
66   // Query all branch interfaces to link block argument dependencies.
67   op->walk([&](BranchOpInterface branchInterface) {
68     Block *parentBlock = branchInterface->getBlock();
69     for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
70          it != e; ++it) {
71       // Query the branch op interface to get the successor operands.
72       auto successorOperands =
73           branchInterface.getSuccessorOperands(it.getIndex());
74       // Build the actual mapping of values to their immediate dependencies.
75       registerDependencies(successorOperands.getForwardedOperands(),
76                            (*it)->getArguments().drop_front(
77                                successorOperands.getProducedOperandCount()));
78     }
79   });
80 
81   // Query the RegionBranchOpInterface to find potential successor regions.
82   op->walk([&](RegionBranchOpInterface regionInterface) {
83     // Extract all entry regions and wire all initial entry successor inputs.
84     SmallVector<RegionSuccessor, 2> entrySuccessors;
85     regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
86                                         entrySuccessors);
87     for (RegionSuccessor &entrySuccessor : entrySuccessors) {
88       // Wire the entry region's successor arguments with the initial
89       // successor inputs.
90       assert(entrySuccessor.getSuccessor() &&
91              "Invalid entry region without an attached successor region");
92       registerDependencies(
93           regionInterface.getSuccessorEntryOperands(
94               entrySuccessor.getSuccessor()->getRegionNumber()),
95           entrySuccessor.getSuccessorInputs());
96     }
97 
98     // Wire flow between regions and from region exits.
99     for (Region &region : regionInterface->getRegions()) {
100       // Iterate over all successor region entries that are reachable from the
101       // current region.
102       SmallVector<RegionSuccessor, 2> successorRegions;
103       regionInterface.getSuccessorRegions(region.getRegionNumber(),
104                                           successorRegions);
105       for (RegionSuccessor &successorRegion : successorRegions) {
106         // Determine the current region index (if any).
107         Optional<unsigned> regionIndex;
108         Region *regionSuccessor = successorRegion.getSuccessor();
109         if (regionSuccessor)
110           regionIndex = regionSuccessor->getRegionNumber();
111         // Iterate over all immediate terminator operations and wire the
112         // successor inputs with the successor operands of each terminator.
113         for (Block &block : region) {
114           auto successorOperands = getRegionBranchSuccessorOperands(
115               block.getTerminator(), regionIndex);
116           if (successorOperands) {
117             registerDependencies(*successorOperands,
118                                  successorRegion.getSuccessorInputs());
119           }
120         }
121       }
122     }
123   });
124 
125   // TODO: This should be an interface.
126   op->walk([&](arith::SelectOp selectOp) {
127     registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()});
128     registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()});
129   });
130 }
131