xref: /llvm-project/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp (revision 0a20ab908ca7cc82a4c206d39d0eaf86a46e1ff0)
1cbd47504SRob Suderman //===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===//
2cbd47504SRob Suderman //
3cbd47504SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cbd47504SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5cbd47504SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cbd47504SRob Suderman //
7cbd47504SRob Suderman //===----------------------------------------------------------------------===//
8cbd47504SRob Suderman 
9cbd47504SRob Suderman #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
10cbd47504SRob Suderman 
11cbd47504SRob Suderman #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
12cbd47504SRob Suderman #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
13cbd47504SRob Suderman #include "mlir/IR/BuiltinOps.h"
14cbd47504SRob Suderman #include "mlir/Pass/Pass.h"
15cbd47504SRob Suderman #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16cbd47504SRob Suderman 
17cbd47504SRob Suderman namespace mlir {
18cbd47504SRob Suderman namespace ml_program {
19cbd47504SRob Suderman #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
20cbd47504SRob Suderman #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21cbd47504SRob Suderman 
22cbd47504SRob Suderman namespace {
23cbd47504SRob Suderman 
24cbd47504SRob Suderman class MLProgramPipelineGlobals
25cbd47504SRob Suderman     : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
26cbd47504SRob Suderman public:
27cbd47504SRob Suderman   void runOnOperation() override;
28cbd47504SRob Suderman 
29cbd47504SRob Suderman private:
30cbd47504SRob Suderman   LogicalResult buildGlobalMap(ModuleOp op);
31cbd47504SRob Suderman 
324e713357SMehdi Amini   void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
33cbd47504SRob Suderman                     llvm::DenseSet<SymbolRefAttr> &symbolStore);
34cbd47504SRob Suderman 
35cbd47504SRob Suderman   llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
36cbd47504SRob Suderman   llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
37cbd47504SRob Suderman };
38cbd47504SRob Suderman 
39cbd47504SRob Suderman // Traverses upwards searchign for the operation mapped by the symbol.
40cbd47504SRob Suderman static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
419cebb285SMehdi Amini   for (auto *op = baseOp; op; op = op->getParentOp()) {
429cebb285SMehdi Amini     auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
43cbd47504SRob Suderman     if (lookup)
44cbd47504SRob Suderman       return lookup;
45cbd47504SRob Suderman   }
46cbd47504SRob Suderman   return nullptr;
47cbd47504SRob Suderman }
48cbd47504SRob Suderman 
49cbd47504SRob Suderman // Builds map from a symbol to MLProgram global symbols loaded or stored
50cbd47504SRob Suderman // during processing.
51cbd47504SRob Suderman LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
52cbd47504SRob Suderman   llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
53cbd47504SRob Suderman   auto res = module->walk([&](Operation *op) {
54cbd47504SRob Suderman     if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55cbd47504SRob Suderman       auto callable = caller.getCallableForCallee();
56cbd47504SRob Suderman       // For now we do not know how to handle Value based tracing, so fail.
57cbd47504SRob Suderman       if (mlir::isa<Value>(callable)) {
58cbd47504SRob Suderman         return WalkResult::interrupt();
59cbd47504SRob Suderman       }
60cbd47504SRob Suderman 
61cbd47504SRob Suderman       auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
629cebb285SMehdi Amini       auto *func = getFromSymbol(op, symbol);
63cbd47504SRob Suderman       callableMap[symbol] = func;
64cbd47504SRob Suderman     }
65cbd47504SRob Suderman     return WalkResult::advance();
66cbd47504SRob Suderman   });
67cbd47504SRob Suderman 
68cbd47504SRob Suderman   if (res.wasInterrupted()) {
69cbd47504SRob Suderman     return failure();
70cbd47504SRob Suderman   }
71cbd47504SRob Suderman 
72cbd47504SRob Suderman   // First grab all symbols loaded or stored by each function. This
73cbd47504SRob Suderman   // will not handle calls initially.
74cbd47504SRob Suderman   llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
75cbd47504SRob Suderman   llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
76cbd47504SRob Suderman   for (auto callable : callableMap) {
77cbd47504SRob Suderman     llvm::DenseSet<SymbolRefAttr> loadSymbols;
78cbd47504SRob Suderman     llvm::DenseSet<SymbolRefAttr> storeSymbols;
79cbd47504SRob Suderman 
80cbd47504SRob Suderman     callable.getSecond()->walk(
81cbd47504SRob Suderman         [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
82cbd47504SRob Suderman 
83cbd47504SRob Suderman     callable.getSecond()->walk(
84cbd47504SRob Suderman         [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
85cbd47504SRob Suderman 
86cbd47504SRob Suderman     opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87cbd47504SRob Suderman     opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
88cbd47504SRob Suderman   }
89cbd47504SRob Suderman 
90cbd47504SRob Suderman   // For each callable function we find each global loaded/stored within the
91cbd47504SRob Suderman   // function or a nested called function. This includes recursion checking to
92cbd47504SRob Suderman   // avoid infinitely recursing.
93cbd47504SRob Suderman   for (auto callable : callableMap) {
94cbd47504SRob Suderman     SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
95cbd47504SRob Suderman     llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
96cbd47504SRob Suderman     llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
97cbd47504SRob Suderman     llvm::DenseSet<SymbolRefAttr> loadSymbols;
98cbd47504SRob Suderman     llvm::DenseSet<SymbolRefAttr> storeSymbols;
99cbd47504SRob Suderman 
100cbd47504SRob Suderman     for (size_t i = 0; i < work.size(); ++i) {
101cbd47504SRob Suderman       callableMap[work[i]]->walk([&](CallOpInterface call) {
102cbd47504SRob Suderman         auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
10371a39ecaSKazu Hirata         if (visited.insert(symbol).second)
104cbd47504SRob Suderman           work.push_back(symbol);
105cbd47504SRob Suderman       });
106cbd47504SRob Suderman 
107cbd47504SRob Suderman       for (auto load : opLoadSymbols[work[i]])
108cbd47504SRob Suderman         loadSymbols.insert(load);
109cbd47504SRob Suderman 
110cbd47504SRob Suderman       for (auto store : opStoreSymbols[work[i]])
111cbd47504SRob Suderman         storeSymbols.insert(store);
112cbd47504SRob Suderman     }
113cbd47504SRob Suderman 
114cbd47504SRob Suderman     loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
115cbd47504SRob Suderman     storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
116cbd47504SRob Suderman   }
117cbd47504SRob Suderman 
118cbd47504SRob Suderman   return success();
119cbd47504SRob Suderman }
120cbd47504SRob Suderman 
121cbd47504SRob Suderman // Process each operation in the block deleting unneeded loads / stores,
122cbd47504SRob Suderman // recursing on subblocks and checking function calls.
1234e713357SMehdi Amini void MLProgramPipelineGlobals::processBlock(
124cbd47504SRob Suderman     Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
125cbd47504SRob Suderman     llvm::DenseSet<SymbolRefAttr> &symbolStore) {
126cbd47504SRob Suderman 
127cbd47504SRob Suderman   llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
128cbd47504SRob Suderman   llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
129cbd47504SRob Suderman   llvm::SmallVector<Operation *> toDelete;
130cbd47504SRob Suderman   for (auto &op : block) {
131cbd47504SRob Suderman     // If this is a global load, remap to a previous value if known
132cbd47504SRob Suderman     // and delete this load. Remember that this value is the currently
133cbd47504SRob Suderman     // known load.
134cbd47504SRob Suderman     if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
135cbd47504SRob Suderman       auto ref = load.getGlobal();
136cbd47504SRob Suderman       symbolLoad.insert(ref);
137cbd47504SRob Suderman       if (previousLoads.contains(ref)) {
138cbd47504SRob Suderman         toDelete.push_back(&op);
139cbd47504SRob Suderman         load.getResult().replaceAllUsesWith(previousLoads[ref]);
140cbd47504SRob Suderman       } else {
141cbd47504SRob Suderman         previousLoads[ref] = load.getResult();
142cbd47504SRob Suderman       }
143cbd47504SRob Suderman       continue;
144cbd47504SRob Suderman     }
145cbd47504SRob Suderman 
146cbd47504SRob Suderman     // Delete a previous store if it exists and is not needed, update
147cbd47504SRob Suderman     // the most recent known value for this global ref.
148cbd47504SRob Suderman     if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
149cbd47504SRob Suderman       auto ref = store.getGlobal();
150cbd47504SRob Suderman       symbolStore.insert(ref);
151*0a20ab90SKazu Hirata       auto it = previousStores.find(ref);
152*0a20ab90SKazu Hirata       if (it != previousStores.end()) {
153*0a20ab90SKazu Hirata         toDelete.push_back(it->getSecond());
154cbd47504SRob Suderman       }
155cbd47504SRob Suderman 
156cbd47504SRob Suderman       previousLoads[ref] = store.getValue();
157cbd47504SRob Suderman       previousStores[ref] = &op;
158cbd47504SRob Suderman       continue;
159cbd47504SRob Suderman     }
160cbd47504SRob Suderman 
161cbd47504SRob Suderman     // If a function is called, clear known values for loads/stores used by
162cbd47504SRob Suderman     // the function or its sub-functions.
163cbd47504SRob Suderman     if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
164cbd47504SRob Suderman       auto loadSymbols =
165cbd47504SRob Suderman           loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
166cbd47504SRob Suderman       auto storeSymbols =
167cbd47504SRob Suderman           storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
168cbd47504SRob Suderman 
169cbd47504SRob Suderman       for (auto sym : loadSymbols) {
170cbd47504SRob Suderman         previousStores.erase(sym);
171cbd47504SRob Suderman       }
172cbd47504SRob Suderman 
173cbd47504SRob Suderman       for (auto sym : storeSymbols) {
174cbd47504SRob Suderman         previousLoads.erase(sym);
175cbd47504SRob Suderman         previousStores.erase(sym);
176cbd47504SRob Suderman       }
177cbd47504SRob Suderman       continue;
178cbd47504SRob Suderman     }
179cbd47504SRob Suderman 
180cbd47504SRob Suderman     // If the op has sub-regions, recurse inside. We make no guarantees whether
181cbd47504SRob Suderman     // the recursion occurs.
182cbd47504SRob Suderman     llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
183cbd47504SRob Suderman     llvm::DenseSet<SymbolRefAttr> opSymbolStore;
184cbd47504SRob Suderman     for (auto &region : op.getRegions()) {
185cbd47504SRob Suderman       for (auto &block : region) {
1864e713357SMehdi Amini         processBlock(block, opSymbolLoad, opSymbolStore);
187cbd47504SRob Suderman       }
188cbd47504SRob Suderman     }
189cbd47504SRob Suderman 
190cbd47504SRob Suderman     // Update current state from the subblock.
191cbd47504SRob Suderman     for (auto change : opSymbolLoad) {
192cbd47504SRob Suderman       symbolLoad.insert(change);
193cbd47504SRob Suderman       previousStores.erase(change);
194cbd47504SRob Suderman     }
195cbd47504SRob Suderman 
196cbd47504SRob Suderman     for (auto change : opSymbolStore) {
197cbd47504SRob Suderman       symbolStore.insert(change);
198cbd47504SRob Suderman       previousLoads.erase(change);
199cbd47504SRob Suderman       previousStores.erase(change);
200cbd47504SRob Suderman     }
201cbd47504SRob Suderman   }
202cbd47504SRob Suderman 
2039cebb285SMehdi Amini   for (auto *op : toDelete) {
204cbd47504SRob Suderman     op->erase();
205cbd47504SRob Suderman   }
206cbd47504SRob Suderman }
207cbd47504SRob Suderman 
208cbd47504SRob Suderman void MLProgramPipelineGlobals::runOnOperation() {
209cbd47504SRob Suderman   auto targetOp = getOperation();
210cbd47504SRob Suderman   if (failed(buildGlobalMap(targetOp))) {
211cbd47504SRob Suderman     return;
212cbd47504SRob Suderman   }
213cbd47504SRob Suderman 
214cbd47504SRob Suderman   for (auto &funcOp : *targetOp.getBody()) {
215cbd47504SRob Suderman     for (auto &region : funcOp.getRegions()) {
216cbd47504SRob Suderman       for (auto &block : region.getBlocks()) {
217cbd47504SRob Suderman         llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
218cbd47504SRob Suderman         llvm::DenseSet<SymbolRefAttr> symbolsStored;
2194e713357SMehdi Amini         processBlock(block, symbolsLoaded, symbolsStored);
220cbd47504SRob Suderman       }
221cbd47504SRob Suderman     }
222cbd47504SRob Suderman   }
223cbd47504SRob Suderman }
224cbd47504SRob Suderman 
225cbd47504SRob Suderman } // namespace
226cbd47504SRob Suderman 
227cbd47504SRob Suderman std::unique_ptr<OperationPass<mlir::ModuleOp>>
228cbd47504SRob Suderman createMLProgramPipelineGlobalsPass() {
229cbd47504SRob Suderman   return std::make_unique<MLProgramPipelineGlobals>();
230cbd47504SRob Suderman }
231cbd47504SRob Suderman 
232cbd47504SRob Suderman } // namespace ml_program
233cbd47504SRob Suderman } // namespace mlir
234