xref: /llvm-project/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp (revision 0a20ab908ca7cc82a4c206d39d0eaf86a46e1ff0)
//===- PipelineGlobalOps.cpp - Pipeline Global Ops Pass -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
12 #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 namespace mlir {
18 namespace ml_program {
19 #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
20 #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21 
22 namespace {
23 
24 class MLProgramPipelineGlobals
25     : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
26 public:
27   void runOnOperation() override;
28 
29 private:
30   LogicalResult buildGlobalMap(ModuleOp op);
31 
32   void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
33                     llvm::DenseSet<SymbolRefAttr> &symbolStore);
34 
35   llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
36   llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
37 };
38 
39 // Traverses upwards searchign for the operation mapped by the symbol.
40 static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
41   for (auto *op = baseOp; op; op = op->getParentOp()) {
42     auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
43     if (lookup)
44       return lookup;
45   }
46   return nullptr;
47 }
48 
49 // Builds map from a symbol to MLProgram global symbols loaded or stored
50 // during processing.
51 LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
52   llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
53   auto res = module->walk([&](Operation *op) {
54     if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55       auto callable = caller.getCallableForCallee();
56       // For now we do not know how to handle Value based tracing, so fail.
57       if (mlir::isa<Value>(callable)) {
58         return WalkResult::interrupt();
59       }
60 
61       auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
62       auto *func = getFromSymbol(op, symbol);
63       callableMap[symbol] = func;
64     }
65     return WalkResult::advance();
66   });
67 
68   if (res.wasInterrupted()) {
69     return failure();
70   }
71 
72   // First grab all symbols loaded or stored by each function. This
73   // will not handle calls initially.
74   llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
75   llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
76   for (auto callable : callableMap) {
77     llvm::DenseSet<SymbolRefAttr> loadSymbols;
78     llvm::DenseSet<SymbolRefAttr> storeSymbols;
79 
80     callable.getSecond()->walk(
81         [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
82 
83     callable.getSecond()->walk(
84         [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
85 
86     opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87     opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
88   }
89 
90   // For each callable function we find each global loaded/stored within the
91   // function or a nested called function. This includes recursion checking to
92   // avoid infinitely recursing.
93   for (auto callable : callableMap) {
94     SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
95     llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
96     llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
97     llvm::DenseSet<SymbolRefAttr> loadSymbols;
98     llvm::DenseSet<SymbolRefAttr> storeSymbols;
99 
100     for (size_t i = 0; i < work.size(); ++i) {
101       callableMap[work[i]]->walk([&](CallOpInterface call) {
102         auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
103         if (visited.insert(symbol).second)
104           work.push_back(symbol);
105       });
106 
107       for (auto load : opLoadSymbols[work[i]])
108         loadSymbols.insert(load);
109 
110       for (auto store : opStoreSymbols[work[i]])
111         storeSymbols.insert(store);
112     }
113 
114     loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
115     storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
116   }
117 
118   return success();
119 }
120 
121 // Process each operation in the block deleting unneeded loads / stores,
122 // recursing on subblocks and checking function calls.
123 void MLProgramPipelineGlobals::processBlock(
124     Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
125     llvm::DenseSet<SymbolRefAttr> &symbolStore) {
126 
127   llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
128   llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
129   llvm::SmallVector<Operation *> toDelete;
130   for (auto &op : block) {
131     // If this is a global load, remap to a previous value if known
132     // and delete this load. Remember that this value is the currently
133     // known load.
134     if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
135       auto ref = load.getGlobal();
136       symbolLoad.insert(ref);
137       if (previousLoads.contains(ref)) {
138         toDelete.push_back(&op);
139         load.getResult().replaceAllUsesWith(previousLoads[ref]);
140       } else {
141         previousLoads[ref] = load.getResult();
142       }
143       continue;
144     }
145 
146     // Delete a previous store if it exists and is not needed, update
147     // the most recent known value for this global ref.
148     if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
149       auto ref = store.getGlobal();
150       symbolStore.insert(ref);
151       auto it = previousStores.find(ref);
152       if (it != previousStores.end()) {
153         toDelete.push_back(it->getSecond());
154       }
155 
156       previousLoads[ref] = store.getValue();
157       previousStores[ref] = &op;
158       continue;
159     }
160 
161     // If a function is called, clear known values for loads/stores used by
162     // the function or its sub-functions.
163     if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
164       auto loadSymbols =
165           loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
166       auto storeSymbols =
167           storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
168 
169       for (auto sym : loadSymbols) {
170         previousStores.erase(sym);
171       }
172 
173       for (auto sym : storeSymbols) {
174         previousLoads.erase(sym);
175         previousStores.erase(sym);
176       }
177       continue;
178     }
179 
180     // If the op has sub-regions, recurse inside. We make no guarantees whether
181     // the recursion occurs.
182     llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
183     llvm::DenseSet<SymbolRefAttr> opSymbolStore;
184     for (auto &region : op.getRegions()) {
185       for (auto &block : region) {
186         processBlock(block, opSymbolLoad, opSymbolStore);
187       }
188     }
189 
190     // Update current state from the subblock.
191     for (auto change : opSymbolLoad) {
192       symbolLoad.insert(change);
193       previousStores.erase(change);
194     }
195 
196     for (auto change : opSymbolStore) {
197       symbolStore.insert(change);
198       previousLoads.erase(change);
199       previousStores.erase(change);
200     }
201   }
202 
203   for (auto *op : toDelete) {
204     op->erase();
205   }
206 }
207 
208 void MLProgramPipelineGlobals::runOnOperation() {
209   auto targetOp = getOperation();
210   if (failed(buildGlobalMap(targetOp))) {
211     return;
212   }
213 
214   for (auto &funcOp : *targetOp.getBody()) {
215     for (auto &region : funcOp.getRegions()) {
216       for (auto &block : region.getBlocks()) {
217         llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
218         llvm::DenseSet<SymbolRefAttr> symbolsStored;
219         processBlock(block, symbolsLoaded, symbolsStored);
220       }
221     }
222   }
223 }
224 
225 } // namespace
226 
227 std::unique_ptr<OperationPass<mlir::ModuleOp>>
228 createMLProgramPipelineGlobalsPass() {
229   return std::make_unique<MLProgramPipelineGlobals>();
230 }
231 
232 } // namespace ml_program
233 } // namespace mlir
234