xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (revision e6edc1bd69c881eabf78c439be7f42a639f0df79)
1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 #include "mlir/Interfaces/ViewLikeInterface.h"
14 #include "llvm/ADT/SetOperations.h"
15 #include "llvm/ADT/SetVector.h"
16 
17 using namespace mlir;
18 
19 /// Constructs a new alias analysis using the op provided.
20 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
21 
22 /// Find all immediate and indirect dependent buffers this value could
23 /// potentially have. Note that the resulting set will also contain the value
24 /// provided as it is a dependent alias of itself.
25 BufferViewFlowAnalysis::ValueSetT
26 BufferViewFlowAnalysis::resolve(Value rootValue) const {
27   ValueSetT result;
28   SmallVector<Value, 8> queue;
29   queue.push_back(rootValue);
30   while (!queue.empty()) {
31     Value currentValue = queue.pop_back_val();
32     if (result.insert(currentValue).second) {
33       auto it = dependencies.find(currentValue);
34       if (it != dependencies.end()) {
35         for (Value aliasValue : it->second)
36           queue.push_back(aliasValue);
37       }
38     }
39   }
40   return result;
41 }
42 
43 /// Removes the given values from all alias sets.
44 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
45   for (auto &entry : dependencies)
46     llvm::set_subtract(entry.second, aliasValues);
47 }
48 
49 /// This function constructs a mapping from values to its immediate
50 /// dependencies. It iterates over all blocks, gets their predecessors,
51 /// determines the values that will be passed to the corresponding block
52 /// arguments and inserts them into the underlying map. Furthermore, it wires
53 /// successor regions and branch-like return operations from nested regions.
54 void BufferViewFlowAnalysis::build(Operation *op) {
55   // Registers all dependencies of the given values.
56   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
57     for (auto [value, dep] : llvm::zip(values, dependencies))
58       this->dependencies[value].insert(dep);
59   };
60 
61   // Add additional dependencies created by view changes to the alias list.
62   op->walk([&](ViewLikeOpInterface viewInterface) {
63     dependencies[viewInterface.getViewSource()].insert(
64         viewInterface->getResult(0));
65   });
66 
67   // Query all branch interfaces to link block argument dependencies.
68   op->walk([&](BranchOpInterface branchInterface) {
69     Block *parentBlock = branchInterface->getBlock();
70     for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
71          it != e; ++it) {
72       // Query the branch op interface to get the successor operands.
73       auto successorOperands =
74           branchInterface.getSuccessorOperands(it.getIndex());
75       // Build the actual mapping of values to their immediate dependencies.
76       registerDependencies(successorOperands.getForwardedOperands(),
77                            (*it)->getArguments().drop_front(
78                                successorOperands.getProducedOperandCount()));
79     }
80   });
81 
82   // Query the RegionBranchOpInterface to find potential successor regions.
83   op->walk([&](RegionBranchOpInterface regionInterface) {
84     // Extract all entry regions and wire all initial entry successor inputs.
85     SmallVector<RegionSuccessor, 2> entrySuccessors;
86     regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
87                                         entrySuccessors);
88     for (RegionSuccessor &entrySuccessor : entrySuccessors) {
89       // Wire the entry region's successor arguments with the initial
90       // successor inputs.
91       assert(entrySuccessor.getSuccessor() &&
92              "Invalid entry region without an attached successor region");
93       registerDependencies(
94           regionInterface.getSuccessorEntryOperands(
95               entrySuccessor.getSuccessor()->getRegionNumber()),
96           entrySuccessor.getSuccessorInputs());
97     }
98 
99     // Wire flow between regions and from region exits.
100     for (Region &region : regionInterface->getRegions()) {
101       // Iterate over all successor region entries that are reachable from the
102       // current region.
103       SmallVector<RegionSuccessor, 2> successorRegions;
104       regionInterface.getSuccessorRegions(region.getRegionNumber(),
105                                           successorRegions);
106       for (RegionSuccessor &successorRegion : successorRegions) {
107         // Determine the current region index (if any).
108         std::optional<unsigned> regionIndex;
109         Region *regionSuccessor = successorRegion.getSuccessor();
110         if (regionSuccessor)
111           regionIndex = regionSuccessor->getRegionNumber();
112         // Iterate over all immediate terminator operations and wire the
113         // successor inputs with the successor operands of each terminator.
114         for (Block &block : region) {
115           auto successorOperands = getRegionBranchSuccessorOperands(
116               block.getTerminator(), regionIndex);
117           if (successorOperands) {
118             registerDependencies(*successorOperands,
119                                  successorRegion.getSuccessorInputs());
120           }
121         }
122       }
123     }
124   });
125 
126   // TODO: This should be an interface.
127   op->walk([&](arith::SelectOp selectOp) {
128     registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()});
129     registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()});
130   });
131 }
132